Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import random | |
| import gc | |
| from typing import Optional, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image, ImageChops, ImageFilter, ImageOps | |
| from diffusers import AutoPipelineForInpainting, LCMScheduler, AutoencoderTiny | |
# Selectable inpainting checkpoints: UI label -> Hugging Face Hub repo id.
# Insertion order is the order shown in the model dropdown.
MODEL_CHOICES = {
    "DreamShaper 8 Inpainting": "Lykon/dreamshaper-8-inpainting",
    "Official SD1.5 Inpainting": "stable-diffusion-v1-5/stable-diffusion-inpainting",
}
# Dropdown default, and the fallback used by resolve_model_id() for unknown labels.
DEFAULT_MODEL_LABEL = "DreamShaper 8 Inpainting"

# Hub repo ids of the LCM-LoRA adapter and the tiny VAE; each can be
# overridden through an environment variable of the same name.
LCM_LORA_ID = os.getenv("LCM_LORA_ID", "latent-consistency/lcm-lora-sdv1-5")
TINY_VAE_ID = os.getenv("TINY_VAE_ID", "madebyollin/taesd")

# Half precision on GPU, full precision on CPU.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32

# Process-wide pipeline cache, filled by load_pipe() and cleared by unload_pipe().
PIPE = None
PIPE_MODEL_ID = None
def to_pil(x) -> Optional[Image.Image]:
    """Coerce one editor payload item into a PIL image.

    PIL images pass through unchanged. NumPy arrays go through
    ``Image.fromarray``; non-uint8 arrays are first clipped to [0, 255] and
    cast to uint8. ``None`` and every other type yield ``None``.

    NOTE(review): a float array scaled to [0, 1] would come out nearly black
    after the uint8 cast; confirm Gradio never delivers that here.
    """
    if isinstance(x, Image.Image):
        return x
    if not isinstance(x, np.ndarray):
        return None
    pixels = x if x.dtype == np.uint8 else np.clip(x, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
def resolve_model_id(model_label: str) -> str:
    """Translate a dropdown label into its Hub repo id.

    Labels not present in MODEL_CHOICES resolve to the repo id of
    DEFAULT_MODEL_LABEL.
    """
    fallback_id = MODEL_CHOICES[DEFAULT_MODEL_LABEL]
    return MODEL_CHOICES.get(model_label, fallback_id)
def unload_pipe():
    """Forget the cached pipeline and try to give its memory back.

    Resets the module-level PIPE / PIPE_MODEL_ID cache to None, runs the
    garbage collector, and empties the CUDA allocator cache when CUDA is
    available.
    """
    global PIPE, PIPE_MODEL_ID
    # Rebinding drops this module's reference to the old pipeline, whether or not one existed.
    PIPE, PIPE_MODEL_ID = None, None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def load_pipe(model_label: str):
    """Return the inpainting pipeline for *model_label*, building it on demand.

    The pipeline is cached in the module-level PIPE / PIPE_MODEL_ID globals.
    A request for the model that is already loaded returns the cached object.
    Otherwise the previous pipeline is unloaded and a new one is built:

    * the base checkpoint is loaded in DTYPE with the safety checker disabled;
    * the scheduler is replaced with LCMScheduler, the LCM-LoRA
      (LCM_LORA_ID) is loaded, and a fuse into the weights is attempted;
    * the VAE is replaced with AutoencoderTiny (TINY_VAE_ID);
    * the pipeline is moved to DEVICE with the progress bar disabled, and
      attention slicing is enabled when supported.

    LoRA fusion and attention slicing are best-effort. If either fails, a
    RuntimeWarning is emitted (instead of the failure being silently
    swallowed) and loading continues.
    """
    global PIPE, PIPE_MODEL_ID
    import warnings

    model_id = resolve_model_id(model_label)
    if PIPE is not None and PIPE_MODEL_ID == model_id:
        return PIPE

    # Free the previous model before loading a new one, to keep peak memory down.
    unload_pipe()

    pipe = AutoPipelineForInpainting.from_pretrained(
        model_id,
        torch_dtype=DTYPE,
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.load_lora_weights(LCM_LORA_ID)
    try:
        pipe.fuse_lora()
    except Exception as exc:
        # The LoRA stays loaded but unfused; surface the reason instead of hiding it.
        warnings.warn(
            f"fuse_lora() failed; continuing with an unfused LCM-LoRA: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )

    pipe.vae = AutoencoderTiny.from_pretrained(
        TINY_VAE_ID,
        torch_dtype=DTYPE,
    )
    pipe = pipe.to(DEVICE)
    pipe.set_progress_bar_config(disable=True)
    try:
        pipe.enable_attention_slicing()
    except Exception as exc:
        warnings.warn(
            f"enable_attention_slicing() failed; continuing without it: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )

    PIPE = pipe
    PIPE_MODEL_ID = model_id
    return PIPE
def extract_image_and_mask(editor_value) -> Tuple[Image.Image, Image.Image]:
    """Split an image-editor value into an RGB base image and an "L" mask.

    *editor_value* is either a dict with ``background``, ``composite`` and
    ``layers`` entries, or a bare image. The mask is the per-pixel maximum of
    the alpha channels of all readable layers, each resized (nearest
    neighbour) to the base image when needed.

    If no layer could be read, the mask falls back to every pixel where the
    composite differs from the background by more than 12 in any RGB
    channel.

    Raises:
        gr.Error: if the value is None, has no base image, yields no mask,
            or yields a mask whose maximum value is below 10.
    """
    if editor_value is None:
        raise gr.Error("Upload an image and draw a mask first.")
    if isinstance(editor_value, dict):
        background = to_pil(editor_value.get("background"))
        composite = to_pil(editor_value.get("composite"))
        layers = editor_value.get("layers") or []
    else:
        # A bare image carries no strokes, so it ends in the "Draw over..." error below.
        background = to_pil(editor_value)
        composite = None
        layers = []
    if background is None:
        raise gr.Error("No base image found. Upload an image first.")
    image = background.convert("RGB")

    # Union of the layers' alpha channels: any painted pixel becomes masked.
    # A layer without alpha is converted to RGBA, which makes it fully opaque,
    # so such a layer marks the whole image.
    mask = None
    for layer in layers:
        layer_img = to_pil(layer)
        if layer_img is None:
            continue
        if layer_img.size != image.size:
            layer_img = layer_img.resize(image.size, Image.Resampling.NEAREST)
        if layer_img.mode != "RGBA":
            layer_img = layer_img.convert("RGBA")
        alpha = layer_img.getchannel("A")
        mask = alpha if mask is None else ImageChops.lighter(mask, alpha)

    if mask is None and composite is not None:
        composite = composite.convert("RGB")
        if composite.size != image.size:
            composite = composite.resize(image.size, Image.Resampling.NEAREST)
        diff = np.asarray(ImageChops.difference(image, composite))
        # Take the largest per-channel difference instead of convert("L").
        # The luma weights (0.299/0.587/0.114) would shrink a blue-only
        # difference of 100 to about 11, below the threshold, and drop a
        # real stroke.
        mask = Image.fromarray(np.where(diff.max(axis=2) > 12, 255, 0).astype(np.uint8))

    if mask is None:
        raise gr.Error("Draw over the area you want to repaint.")
    mask = mask.convert("L")
    if np.array(mask).max() < 10:
        raise gr.Error("Mask is empty. Draw over the area you want to repaint.")
    return image, mask
def mask_bbox(mask: Image.Image, threshold: int = 10) -> Tuple[int, int, int, int]:
    """Return the tight ``(left, top, right, bottom)`` box around the mask.

    Only pixels strictly above *threshold* (after conversion to "L") count.
    ``right`` and ``bottom`` are exclusive, so the result can be passed
    straight to ``Image.crop``.

    Raises:
        gr.Error: if no pixel exceeds *threshold*.
    """
    hits = np.asarray(mask.convert("L")) > threshold
    cols = np.flatnonzero(hits.any(axis=0))
    rows = np.flatnonzero(hits.any(axis=1))
    if cols.size == 0 or rows.size == 0:
        raise gr.Error("Mask is empty. Draw over the area you want to repaint.")
    return int(cols[0]), int(rows[0]), int(cols[-1]) + 1, int(rows[-1]) + 1
def clamp_bbox(
    bbox: Tuple[int, int, int, int],
    image_size: Tuple[int, int],
) -> Tuple[int, int, int, int]:
    """Clip each coordinate of *bbox* into the image.

    x coordinates are limited to [0, width] and y coordinates to
    [0, height]. Each coordinate is clipped on its own, so the box is not
    re-ordered.
    """
    x0, y0, x1, y1 = bbox
    width, height = image_size
    pairs = ((x0, width), (y0, height), (x1, width), (y1, height))
    return tuple(max(0, min(value, limit)) for value, limit in pairs)
def expand_bbox(
    bbox: Tuple[int, int, int, int],
    image_size: Tuple[int, int],
    padding: int,
    min_side: int,
) -> Tuple[int, int, int, int]:
    """Grow *bbox* by *padding* and up to *min_side* per axis, inside the image.

    Each axis is handled on its own:

    1. add *padding* on both sides and clip to the image;
    2. if the span is shorter than *min_side*, widen it around its centre
       (an odd extra pixel goes to the far edge);
    3. slide a span that sticks out past either image edge back inside;
    4. clip again, so an image smaller than *min_side* yields the full axis.
    """
    image_w, image_h = image_size
    left, top, right, bottom = bbox

    def fit(lo, hi, limit):
        # Steps 1-4 above for one axis; returns the new (lo, hi).
        lo = max(0, min(lo - padding, limit))
        hi = max(0, min(hi + padding, limit))
        missing = min_side - (hi - lo)
        if missing > 0:
            lo -= missing // 2
            hi += missing - missing // 2
        if lo < 0:
            lo, hi = 0, hi - lo
        if hi > limit:
            lo, hi = lo - (hi - limit), limit
        return max(0, min(lo, limit)), max(0, min(hi, limit))

    left, right = fit(left, right, image_w)
    top, bottom = fit(top, bottom, image_h)
    return left, top, right, bottom
def pad_to_multiple_of_8(
    image: Image.Image,
    mask: Image.Image,
) -> Tuple[Image.Image, Image.Image, Tuple[int, int], Tuple[int, int]]:
    """Pad *image* and *mask* on the right/bottom so both sides divide by 8.

    The added pixels are filled with 0 in both images. The padding amounts
    are computed from *image*'s size and applied to *mask* as well.

    Returns:
        ``(image, mask, (orig_w, orig_h), (padded_w, padded_h))``. The inputs
        are returned unchanged when no padding is required.
    """
    width, height = image.size
    target_w = -(-width // 8) * 8   # ceil to the next multiple of 8
    target_h = -(-height // 8) * 8
    if (target_w, target_h) == (width, height):
        return image, mask, (width, height), (target_w, target_h)

    border = (0, 0, target_w - width, target_h - height)
    return (
        ImageOps.expand(image, border=border, fill=0),
        ImageOps.expand(mask, border=border, fill=0),
        (width, height),
        (target_w, target_h),
    )
def make_crop_inputs(
    image: Image.Image,
    mask: Image.Image,
    crop_padding: int,
    min_crop_side: int,
):
    """Cut a working window around the mask and pad it for the pipeline.

    The window is the mask's tight bbox, grown by *crop_padding* and to at
    least *min_crop_side* per axis (see expand_bbox). The cropped image and
    mask are then padded to multiples of 8.

    Returns:
        dict with keys ``raw_bbox``, ``crop_bbox``, ``crop_image``,
        ``crop_mask``, ``padded_image``, ``padded_mask``, ``crop_size`` and
        ``padded_size``.
    """
    raw_bbox = mask_bbox(mask)
    crop_bbox = expand_bbox(
        raw_bbox,
        image.size,
        padding=int(crop_padding),
        min_side=int(min_crop_side),
    )
    crop_image, crop_mask = image.crop(crop_bbox), mask.crop(crop_bbox)
    padded_image, padded_mask, crop_size, padded_size = pad_to_multiple_of_8(
        crop_image, crop_mask
    )
    return dict(
        raw_bbox=raw_bbox,
        crop_bbox=crop_bbox,
        crop_image=crop_image,
        crop_mask=crop_mask,
        padded_image=padded_image,
        padded_mask=padded_mask,
        crop_size=crop_size,
        padded_size=padded_size,
    )
def run_inpaint(
    editor_value,
    model_label: str,
    prompt: str,
    crop_padding: int,
    work_mode: str,
    min_working_window: int,
    steps: int,
    guidance_scale: float,
    strength: float,
    seed: int,
    random_seed: bool,
    mask_blur: int,
    paste_whole_crop: bool,
):
    """Inpaint the masked area of the editor image and describe the run.

    Args:
        editor_value: image-editor value (see extract_image_and_mask).
        model_label: key of MODEL_CHOICES.
        prompt: text prompt for the masked area; must not be blank.
        crop_padding: context in pixels added around the mask bbox
            ("Crop around mask" mode only).
        work_mode: "Crop around mask" to generate only a local window; any
            other value generates on the whole (8-padded) image.
        min_working_window: minimum side of the crop window.
        steps, guidance_scale, strength: passed to the pipeline.
        seed: generator seed. It is ignored (replaced by a random one) when
            *random_seed* is set or the value is None.
        random_seed: draw a fresh seed for this run.
        mask_blur: Gaussian blur radius for the mask; 0 disables blurring.
        paste_whole_crop: diagnostic mode that pastes the whole generated
            window instead of only the masked pixels.

    Returns:
        ``(final_result, info)``: the full-size RGB result and a Markdown
        report.

    Raises:
        gr.Error: blank prompt, missing image or mask, or a steps/strength
            pair that leaves the pipeline no denoising step.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Write a prompt for the masked area.")

    # Validate the cheap user inputs first, so a missing image or empty mask
    # fails immediately instead of after a (possibly slow) model load.
    original_image, original_mask = extract_image_and_mask(editor_value)
    original_image = original_image.convert("RGB")
    original_mask = original_mask.convert("L")

    # diffusers' inpaint pipelines run int(steps * strength) denoising steps
    # and raise a ValueError when that is 0. The UI sliders can reach this
    # (e.g. steps=2, strength=0.3), so report it as a readable error.
    if int(int(steps) * float(strength)) < 1:
        raise gr.Error(
            "Steps x strength is below 1, so no denoising step would run. "
            "Increase steps or strength."
        )

    # A cleared Seed field arrives as None; use a random seed instead of
    # crashing in int(None).
    if random_seed or seed is None:
        seed = random.randint(0, 2**31 - 1)

    model_id = resolve_model_id(model_label)
    pipe = load_pipe(model_label)

    if mask_blur > 0:
        generation_mask_source = original_mask.filter(
            ImageFilter.GaussianBlur(radius=int(mask_blur))
        )
    else:
        generation_mask_source = original_mask

    crop_mode = work_mode == "Crop around mask"
    if crop_mode:
        crop_info = make_crop_inputs(
            original_image,
            generation_mask_source,
            crop_padding=int(crop_padding),
            min_crop_side=int(min_working_window),
        )
        gen_image = crop_info["padded_image"]
        gen_mask = crop_info["padded_mask"]
        crop_bbox = crop_info["crop_bbox"]
        crop_size = crop_info["crop_size"]
        padded_size = crop_info["padded_size"]
    else:
        gen_image, gen_mask, crop_size, padded_size = pad_to_multiple_of_8(
            original_image,
            generation_mask_source,
        )
        crop_bbox = (0, 0, original_image.size[0], original_image.size[1])

    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    start = time.perf_counter()
    with torch.inference_mode():
        generated = pipe(
            prompt=prompt.strip(),
            image=gen_image,
            mask_image=gen_mask,
            height=gen_image.size[1],
            width=gen_image.size[0],
            num_inference_steps=int(steps),
            guidance_scale=float(guidance_scale),
            strength=float(strength),
            generator=generator,
        ).images[0].convert("RGB")
    elapsed = time.perf_counter() - start

    # Remove the right/bottom padding added for the multiple-of-8 constraint.
    if padded_size != crop_size:
        generated = generated.crop((0, 0, crop_size[0], crop_size[1]))

    if crop_mode:
        left, top, _, _ = crop_bbox
        final_result = original_image.copy()
        if paste_whole_crop:
            final_result.paste(generated, (left, top))
            paste_mode = "whole generated crop"
        else:
            # Reuse the (optionally blurred) generation mask. The paste mask
            # then matches what the model inpainted, and the blur runs once.
            final_mask = generation_mask_source.crop(crop_bbox)
            final_result.paste(generated, (left, top), final_mask)
            paste_mode = "masked area only"
        crop_report = (
            f"- raw mask bbox: `{crop_info['raw_bbox']}`\n"
            f"- working crop bbox: `{crop_bbox}`\n"
            f"- working crop size: `{crop_size[0]}x{crop_size[1]}`\n"
            f"- generation size: `{padded_size[0]}x{padded_size[1]}`\n"
            f"- paste mode: `{paste_mode}`"
        )
    else:
        if paste_whole_crop:
            final_result = generated
            paste_mode = "whole generated image"
        else:
            final_result = Image.composite(
                generated, original_image, generation_mask_source
            )
            paste_mode = "masked area only"
        crop_report = (
            "- working crop: `none`\n"
            f"- original size: `{original_image.size[0]}x{original_image.size[1]}`\n"
            f"- generation size: `{padded_size[0]}x{padded_size[1]}`\n"
            f"- paste mode: `{paste_mode}`"
        )

    padding_used = "yes" if padded_size != crop_size else "no"
    info = (
        "**Done**\n\n"
        f"- device: `{DEVICE}`\n"
        f"- selected model: `{model_label}`\n"
        f"- model id: `{model_id}`\n"
        "- speed trick: `LCM-LoRA`\n"
        "- vae: `TAESD`\n"
        f"- work mode: `{work_mode}`\n"
        "- resize: `none`\n"
        f"- context around mask: `{crop_padding}` px\n"
        f"- minimum working window: `{min_working_window}` px\n"
        f"- padding to multiple of 8: `{padding_used}`\n"
        f"{crop_report}\n"
        f"- steps: `{steps}`\n"
        f"- guidance: `{guidance_scale}`\n"
        f"- strength: `{strength}`\n"
        f"- mask blur: `{mask_blur}`\n"
        f"- seed: `{seed}`\n"
        f"- time: `{elapsed:.1f}s`\n\n"
        "This version does not downscale or upscale. "
        "In crop mode it sends the mask bbox plus surrounding context to the model, "
        "then pastes the result back into the original image."
    )
    return final_result, info
# Gradio UI. The layout is built at import time; the server only starts when
# the file is run as a script (see the __main__ guard at the bottom).
with gr.Blocks(title="SD15 Light Inpaint CPU") as demo:
    gr.Markdown(
        "# SD15 Light Inpaint CPU\n\n"
        "Upload an image, draw over the area you want to repaint, and describe what should appear there.\n\n"
        "Default mode uses a local crop around the mask instead of resizing the whole image. "
        "This keeps the model working at the original local visual scale while still reducing the amount of image sent to the model."
    )
    with gr.Row():
        with gr.Column():
            editor = gr.ImageEditor(
                label="Image + mask",
                type="pil",
            )
            model_label = gr.Dropdown(
                label="Base inpaint model",
                choices=list(MODEL_CHOICES.keys()),
                value=DEFAULT_MODEL_LABEL,
            )
            prompt = gr.Textbox(
                label="Prompt for masked area",
                value="a beautiful fantasy detail, coherent with the original image, natural lighting",
                lines=3,
            )
            crop_padding = gr.Slider(
                label="Context around mask",
                minimum=32,
                maximum=384,
                step=32,
                value=128,
            )
            with gr.Accordion("Advanced", open=False):
                work_mode = gr.Radio(
                    label="Work mode",
                    choices=[
                        "Crop around mask",
                        "Full image no resize",
                    ],
                    value="Crop around mask",
                )
                min_working_window = gr.Slider(
                    label="Minimum working window size",
                    minimum=128,
                    maximum=768,
                    step=64,
                    value=384,
                )
                paste_whole_crop = gr.Checkbox(
                    label="Paste whole generated crop back, diagnostic",
                    value=False,
                )
                steps = gr.Slider(
                    label="Steps",
                    minimum=2,
                    maximum=8,
                    step=1,
                    value=4,
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale / CFG. LCM usually works best around 1.0-2.0",
                    minimum=1.0,
                    maximum=8.0,
                    step=0.1,
                    value=1.5,
                )
                strength = gr.Slider(
                    label="Strength",
                    minimum=0.3,
                    maximum=1.0,
                    step=0.05,
                    value=0.85,
                )
                mask_blur = gr.Slider(
                    label="Mask blur",
                    minimum=0,
                    maximum=16,
                    step=1,
                    value=0,
                )
                seed = gr.Number(
                    label="Seed",
                    value=12345,
                    precision=0,
                )
                random_seed = gr.Checkbox(
                    label="Random seed",
                    value=False,
                )
            button = gr.Button("Generate", variant="primary")
        with gr.Column():
            output = gr.Image(label="Result", type="pil")
            info = gr.Markdown()
    # The inputs order must match run_inpaint's positional parameters.
    button.click(
        fn=run_inpaint,
        inputs=[
            editor,
            model_label,
            prompt,
            crop_padding,
            work_mode,
            min_working_window,
            steps,
            guidance_scale,
            strength,
            seed,
            random_seed,
            mask_blur,
            paste_whole_crop,
        ],
        outputs=[output, info],
    )

# Configure the request queue at import time, so it also applies when another
# runner imports `demo` and launches it.
demo.queue(max_size=4)

if __name__ == "__main__":
    # Start the server only when run as a script, not as an import side effect.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False,
    )