MetricMogul's picture
Update app.py
0dc9237 verified
Raw
History Blame Contribute Delete
16.5 kB
import os
import time
import random
import gc
from typing import Optional, Tuple
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageChops, ImageFilter, ImageOps
from diffusers import AutoPipelineForInpainting, LCMScheduler, AutoencoderTiny
# Dropdown label -> Hugging Face repo id of the inpainting checkpoint to load.
MODEL_CHOICES = {
    "DreamShaper 8 Inpainting": "Lykon/dreamshaper-8-inpainting",
    "Official SD1.5 Inpainting": "stable-diffusion-v1-5/stable-diffusion-inpainting",
}
DEFAULT_MODEL_LABEL = "DreamShaper 8 Inpainting"

# Hub repos for the LCM-LoRA adapter and the tiny autoencoder; each can be
# overridden with an environment variable of the same name.
LCM_LORA_ID = os.getenv("LCM_LORA_ID", "latent-consistency/lcm-lora-sdv1-5")
TINY_VAE_ID = os.getenv("TINY_VAE_ID", "madebyollin/taesd")

# Half precision only when running on CUDA; CPU inference stays in float32.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.float16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32

# Single-slot pipeline cache, filled by load_pipe() and cleared by unload_pipe().
PIPE = None
PIPE_MODEL_ID = None
def to_pil(x) -> Optional[Image.Image]:
    """Coerce a Gradio image payload into a PIL image.

    PIL images are returned unchanged. NumPy arrays are converted with
    ``Image.fromarray``; arrays that are not ``uint8`` are first clipped to
    ``[0, 255]`` and cast. ``None`` and any other type yield ``None``.
    """
    if isinstance(x, Image.Image):
        return x
    if not isinstance(x, np.ndarray):
        # Covers None as well as any unsupported payload type.
        return None
    pixels = x if x.dtype == np.uint8 else np.clip(x, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
def resolve_model_id(model_label: str) -> str:
    """Return the Hub repo id for ``model_label``, or the default model's id if unknown."""
    fallback_id = MODEL_CHOICES[DEFAULT_MODEL_LABEL]
    return MODEL_CHOICES.get(model_label, fallback_id)
def unload_pipe():
    """Forget the cached pipeline and release the memory it held.

    Clears both module-level cache slots, runs the garbage collector and,
    when CUDA is available, empties the CUDA caching allocator.
    """
    global PIPE, PIPE_MODEL_ID
    # Dropping the only module-level reference lets gc reclaim the weights.
    PIPE = None
    PIPE_MODEL_ID = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def load_pipe(model_label: str):
    """Return the inpainting pipeline for ``model_label``, loading it if needed.

    The pipeline is cached in the module-level ``PIPE`` / ``PIPE_MODEL_ID``
    globals. Asking again for the model that is already loaded returns the
    cached object. Switching models unloads the previous pipeline first, so
    only one is held at a time.

    A freshly loaded pipeline gets:

    - an ``LCMScheduler``;
    - the ``LCM_LORA_ID`` LoRA weights, fused when possible;
    - the ``TINY_VAE_ID`` autoencoder as its VAE.

    It is then moved to ``DEVICE`` and its progress bar is disabled.

    LoRA fusing and attention slicing are best-effort. If either fails, a
    ``RuntimeWarning`` is emitted (instead of the failure being silently
    swallowed) and loading continues.
    """
    import warnings

    global PIPE, PIPE_MODEL_ID

    model_id = resolve_model_id(model_label)
    if PIPE is not None and PIPE_MODEL_ID == model_id:
        return PIPE

    # Release the previous model before loading another one to limit peak memory.
    unload_pipe()

    pipe = AutoPipelineForInpainting.from_pretrained(
        model_id,
        torch_dtype=DTYPE,
        safety_checker=None,
        requires_safety_checker=False,
    )
    # The LCM-LoRA adapter loaded below is meant to be sampled with the LCM scheduler.
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.load_lora_weights(LCM_LORA_ID)
    try:
        pipe.fuse_lora()
    except Exception as exc:
        # Best-effort: the loaded LoRA presumably stays active, just unfused.
        warnings.warn(
            f"fuse_lora() failed ({exc!r}); continuing with unfused LoRA weights.",
            RuntimeWarning,
            stacklevel=2,
        )
    # Replace the checkpoint's VAE with the tiny autoencoder from TINY_VAE_ID.
    pipe.vae = AutoencoderTiny.from_pretrained(
        TINY_VAE_ID,
        torch_dtype=DTYPE,
    )
    pipe = pipe.to(DEVICE)
    pipe.set_progress_bar_config(disable=True)
    try:
        pipe.enable_attention_slicing()
    except Exception as exc:
        warnings.warn(
            f"enable_attention_slicing() failed ({exc!r}); continuing without it.",
            RuntimeWarning,
            stacklevel=2,
        )

    PIPE = pipe
    PIPE_MODEL_ID = model_id
    return PIPE
def extract_image_and_mask(editor_value) -> Tuple[Image.Image, Image.Image]:
    """Split an ImageEditor value into an RGB base image and an ``"L"`` mask.

    ``editor_value`` is either a dict with ``background`` / ``composite`` /
    ``layers`` entries, or a bare image used as the background.

    The mask is the per-pixel maximum of the alpha channels of every drawn
    layer. If no usable layer exists, it falls back to thresholding the
    grayscale difference between the background and the composite
    (``> 12`` becomes 255, everything else 0).

    Raises:
        gr.Error: when there is no editor value, no base image, nothing to
            derive a mask from, or the mask's maximum value is below 10.
    """
    if editor_value is None:
        raise gr.Error("Upload an image and draw a mask first.")

    if isinstance(editor_value, dict):
        background = to_pil(editor_value.get("background"))
        composite = to_pil(editor_value.get("composite"))
        layers = editor_value.get("layers") or []
    else:
        background, composite, layers = to_pil(editor_value), None, []

    if background is None:
        raise gr.Error("No base image found. Upload an image first.")
    image = background.convert("RGB")

    # Collect the alpha channel of every usable layer at the base image size.
    alphas = []
    for raw_layer in layers:
        layer_img = to_pil(raw_layer)
        if layer_img is None:
            continue
        if layer_img.size != image.size:
            layer_img = layer_img.resize(image.size, Image.Resampling.NEAREST)
        if layer_img.mode != "RGBA":
            layer_img = layer_img.convert("RGBA")
        alphas.append(layer_img.getchannel("A"))

    # Union of the strokes: keep the strongest alpha at each pixel.
    mask = None
    for alpha in alphas:
        mask = alpha if mask is None else ImageChops.lighter(mask, alpha)

    if mask is None and composite is not None:
        flattened = composite.convert("RGB")
        if flattened.size != image.size:
            flattened = flattened.resize(image.size, Image.Resampling.NEAREST)
        delta = ImageChops.difference(image, flattened).convert("L")
        mask = delta.point(lambda value: 255 if value > 12 else 0)

    if mask is None:
        raise gr.Error("Draw over the area you want to repaint.")

    mask = mask.convert("L")
    if np.array(mask).max() < 10:
        raise gr.Error("Mask is empty. Draw over the area you want to repaint.")
    return image, mask
def mask_bbox(mask: Image.Image, threshold: int = 10) -> Tuple[int, int, int, int]:
    """Return the tight ``(left, top, right, bottom)`` box around masked pixels.

    A pixel counts as masked when its grayscale value is strictly greater than
    ``threshold``. ``right`` and ``bottom`` are exclusive (one past the last
    masked column/row), so the box can be passed straight to ``Image.crop``.

    Raises:
        gr.Error: if no pixel exceeds ``threshold``.
    """
    hits = np.array(mask.convert("L")) > threshold
    cols = np.flatnonzero(hits.any(axis=0))
    rows = np.flatnonzero(hits.any(axis=1))
    if cols.size == 0 or rows.size == 0:
        raise gr.Error("Mask is empty. Draw over the area you want to repaint.")
    return int(cols[0]), int(rows[0]), int(cols[-1]) + 1, int(rows[-1]) + 1
def clamp_bbox(
    bbox: Tuple[int, int, int, int],
    image_size: Tuple[int, int],
) -> Tuple[int, int, int, int]:
    """Clip every edge of ``bbox`` to the image bounds.

    ``left``/``right`` are limited to ``[0, width]`` and ``top``/``bottom``
    to ``[0, height]``, each independently. An inverted box is returned
    inverted; it is not reordered.
    """
    width, height = image_size
    x0, y0, x1, y1 = bbox

    def _clip(value, upper):
        return max(0, min(value, upper))

    return _clip(x0, width), _clip(y0, height), _clip(x1, width), _clip(y1, height)
def expand_bbox(
    bbox: Tuple[int, int, int, int],
    image_size: Tuple[int, int],
    padding: int,
    min_side: int,
) -> Tuple[int, int, int, int]:
    """Grow ``bbox`` into a working window that stays inside the image.

    Steps:

    1. Pad every side by ``padding`` and clip to the image.
    2. Widen any dimension shorter than ``min_side`` symmetrically; an odd
       remainder goes to the right/bottom.
    3. Slide a window that now overhangs an edge back inside the image.
    4. Clip again.

    If the image itself is smaller than ``min_side``, the window is limited to
    the image.

    Returns:
        ``(left, top, right, bottom)`` with exclusive right/bottom edges.
    """
    image_w, image_h = image_size
    x0, y0, x1, y1 = bbox
    x0, y0, x1, y1 = clamp_bbox(
        (x0 - padding, y0 - padding, x1 + padding, y1 + padding),
        image_size,
    )

    def _widen(lo, hi):
        # Split the shortfall across both ends, extra pixel on the high end.
        shortfall = min_side - (hi - lo)
        if shortfall <= 0:
            return lo, hi
        half = shortfall // 2
        return lo - half, hi + (shortfall - half)

    def _slide_inside(lo, hi, limit):
        # Shift (rather than shrink) the window when it crosses an edge.
        if lo < 0:
            lo, hi = 0, hi - lo
        if hi > limit:
            lo, hi = lo - (hi - limit), limit
        return lo, hi

    x0, x1 = _slide_inside(*_widen(x0, x1), image_w)
    y0, y1 = _slide_inside(*_widen(y0, y1), image_h)
    return clamp_bbox((x0, y0, x1, y1), image_size)
def pad_to_multiple_of_8(
    image: Image.Image,
    mask: Image.Image,
) -> Tuple[Image.Image, Image.Image, Tuple[int, int], Tuple[int, int]]:
    """Pad ``image`` and ``mask`` on the right/bottom to dimensions divisible by 8.

    Added pixels are filled with 0: black in the image, unmasked in the mask.
    When the size is already divisible by 8, both inputs are returned
    unchanged.

    Returns:
        ``(image, mask, (orig_w, orig_h), (padded_w, padded_h))``.
    """
    width, height = image.size
    # Ceiling division, then scale back up to the next multiple of 8.
    target_w = -(-width // 8) * 8
    target_h = -(-height // 8) * 8
    border = (0, 0, target_w - width, target_h - height)
    if border[2] or border[3]:
        image = ImageOps.expand(image, border=border, fill=0)
        mask = ImageOps.expand(mask, border=border, fill=0)
    return image, mask, (width, height), (target_w, target_h)
def make_crop_inputs(
    image: Image.Image,
    mask: Image.Image,
    crop_padding: int,
    min_crop_side: int,
):
    """Cut a working window around the mask and pad it for the pipeline.

    The window is the mask's tight bbox expanded by ``crop_padding`` and grown
    to at least ``min_crop_side`` where the image allows. The cropped image
    and mask are then padded to multiples of 8.

    Returns:
        dict with these keys:

        - ``raw_bbox``: tight bbox of the mask.
        - ``crop_bbox``: the expanded window.
        - ``crop_image`` / ``crop_mask``: the unpadded crops.
        - ``padded_image`` / ``padded_mask``: the padded crops.
        - ``crop_size`` / ``padded_size``: ``(w, h)`` before and after padding.
    """
    tight_box = mask_bbox(mask)
    window = expand_bbox(
        tight_box,
        image.size,
        padding=int(crop_padding),
        min_side=int(min_crop_side),
    )
    cropped_image = image.crop(window)
    cropped_mask = mask.crop(window)
    gen_image, gen_mask, unpadded_size, gen_size = pad_to_multiple_of_8(
        cropped_image,
        cropped_mask,
    )
    return dict(
        raw_bbox=tight_box,
        crop_bbox=window,
        crop_image=cropped_image,
        crop_mask=cropped_mask,
        padded_image=gen_image,
        padded_mask=gen_mask,
        crop_size=unpadded_size,
        padded_size=gen_size,
    )
def run_inpaint(
    editor_value,
    model_label: str,
    prompt: str,
    crop_padding: int,
    work_mode: str,
    min_working_window: int,
    steps: int,
    guidance_scale: float,
    strength: float,
    seed: int,
    random_seed: bool,
    mask_blur: int,
    paste_whole_crop: bool,
):
    """Inpaint the masked area of an ImageEditor value and report what ran.

    Args:
        editor_value: ImageEditor payload (see ``extract_image_and_mask``).
        model_label: key of ``MODEL_CHOICES``; unknown labels use the default.
        prompt: text prompt for the masked area (must not be blank).
        crop_padding: context in px added around the mask bbox in crop mode.
        work_mode: ``"Crop around mask"`` or any other value for full-image mode.
        min_working_window: minimum crop side in px in crop mode.
        steps, guidance_scale, strength: forwarded to the diffusers pipeline.
        seed: generator seed. If ``None`` (e.g. a cleared Number field), a
            random seed is drawn.
        random_seed: draw a random seed regardless of ``seed``.
        mask_blur: Gaussian blur radius applied to the mask. The same blurred
            mask is used for generation and for compositing the result.
        paste_whole_crop: paste the whole generated crop/image instead of only
            the masked area (diagnostic).

    Returns:
        ``(final_image, info_markdown)``.

    Raises:
        gr.Error: on a blank prompt, a missing image or mask, or a
            ``steps``/``strength`` pair that leaves no denoising step.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Write a prompt for the masked area.")

    num_steps = int(steps)
    denoise_strength = float(strength)
    # diffusers' inpaint pipeline keeps int(steps * strength) denoising steps
    # and raises a ValueError when that is 0 (e.g. steps=2, strength=0.3,
    # which the UI sliders allow). Report it as a readable UI error instead.
    if int(num_steps * denoise_strength) < 1:
        raise gr.Error(
            "Steps x strength is below 1, so no denoising step would run. "
            "Increase steps or strength."
        )

    # Validate the image and mask before the potentially slow model load.
    original_image, original_mask = extract_image_and_mask(editor_value)
    original_image = original_image.convert("RGB")
    original_mask = original_mask.convert("L")

    model_id = resolve_model_id(model_label)
    pipe = load_pipe(model_label)

    # One blurred mask drives both generation and compositing, so the pasted
    # region is exactly the region the model was asked to repaint. Blurring
    # before cropping also avoids edge artifacts at the crop border.
    if mask_blur > 0:
        soft_mask = original_mask.filter(
            ImageFilter.GaussianBlur(radius=int(mask_blur))
        )
    else:
        soft_mask = original_mask

    crop_mode = work_mode == "Crop around mask"
    if crop_mode:
        crop_info = make_crop_inputs(
            original_image,
            soft_mask,
            crop_padding=int(crop_padding),
            min_crop_side=int(min_working_window),
        )
        gen_image = crop_info["padded_image"]
        gen_mask = crop_info["padded_mask"]
        crop_bbox = crop_info["crop_bbox"]
        crop_size = crop_info["crop_size"]
        padded_size = crop_info["padded_size"]
    else:
        gen_image, gen_mask, crop_size, padded_size = pad_to_multiple_of_8(
            original_image,
            soft_mask,
        )

    # gr.Number can deliver None when the field is cleared; draw a seed then.
    if random_seed or seed is None:
        seed = random.randint(0, 2**31 - 1)
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

    start = time.perf_counter()
    with torch.inference_mode():
        generated = pipe(
            prompt=prompt.strip(),
            image=gen_image,
            mask_image=gen_mask,
            height=gen_image.size[1],
            width=gen_image.size[0],
            num_inference_steps=num_steps,
            guidance_scale=float(guidance_scale),
            strength=denoise_strength,
            generator=generator,
        ).images[0].convert("RGB")
    elapsed = time.perf_counter() - start

    # Remove the right/bottom padding added by pad_to_multiple_of_8.
    if padded_size != crop_size:
        generated = generated.crop((0, 0, crop_size[0], crop_size[1]))

    if crop_mode:
        left, top, _, _ = crop_bbox
        final_result = original_image.copy()
        if paste_whole_crop:
            final_result.paste(generated, (left, top))
            paste_mode = "whole generated crop"
        else:
            final_result.paste(generated, (left, top), soft_mask.crop(crop_bbox))
            paste_mode = "masked area only"
        crop_report = (
            f"- raw mask bbox: `{crop_info['raw_bbox']}`\n"
            f"- working crop bbox: `{crop_bbox}`\n"
            f"- working crop size: `{crop_size[0]}x{crop_size[1]}`\n"
            f"- generation size: `{padded_size[0]}x{padded_size[1]}`\n"
            f"- paste mode: `{paste_mode}`"
        )
    else:
        if paste_whole_crop:
            final_result = generated
            paste_mode = "whole generated image"
        else:
            final_result = Image.composite(generated, original_image, soft_mask)
            paste_mode = "masked area only"
        crop_report = (
            "- working crop: `none`\n"
            f"- original size: `{original_image.size[0]}x{original_image.size[1]}`\n"
            f"- generation size: `{padded_size[0]}x{padded_size[1]}`\n"
            f"- paste mode: `{paste_mode}`"
        )

    padding_used = "yes" if padded_size != crop_size else "no"
    info = (
        "**Done**\n\n"
        f"- device: `{DEVICE}`\n"
        f"- selected model: `{model_label}`\n"
        f"- model id: `{model_id}`\n"
        "- speed trick: `LCM-LoRA`\n"
        "- vae: `TAESD`\n"
        f"- work mode: `{work_mode}`\n"
        "- resize: `none`\n"
        f"- context around mask: `{crop_padding}` px\n"
        f"- minimum working window: `{min_working_window}` px\n"
        f"- padding to multiple of 8: `{padding_used}`\n"
        f"{crop_report}\n"
        f"- steps: `{steps}`\n"
        f"- guidance: `{guidance_scale}`\n"
        f"- strength: `{strength}`\n"
        f"- mask blur: `{mask_blur}`\n"
        f"- seed: `{seed}`\n"
        f"- time: `{elapsed:.1f}s`\n\n"
        "This version does not downscale or upscale. "
        "In crop mode it sends the mask bbox plus surrounding context to the model, "
        "then pastes the result back into the original image."
    )
    return final_result, info
# Gradio UI. Left column: inputs and controls. Right column: result image and
# run report. Components are declared in display order, so statement order and
# `with` nesting here define the rendered layout.
with gr.Blocks(title="SD15 Light Inpaint CPU") as demo:
    gr.Markdown(
        "# SD15 Light Inpaint CPU\n\n"
        "Upload an image, draw over the area you want to repaint, and describe what should appear there.\n\n"
        "Default mode uses a local crop around the mask instead of resizing the whole image. "
        "This keeps the model working at the original local visual scale while still reducing the amount of image sent to the model."
    )
    with gr.Row():
        with gr.Column():
            # Its value is decoded by extract_image_and_mask() in run_inpaint().
            editor = gr.ImageEditor(
                label="Image + mask",
                type="pil",
            )
            # Choices are the labels of MODEL_CHOICES; resolve_model_id() maps them to repo ids.
            model_label = gr.Dropdown(
                label="Base inpaint model",
                choices=list(MODEL_CHOICES.keys()),
                value=DEFAULT_MODEL_LABEL,
            )
            prompt = gr.Textbox(
                label="Prompt for masked area",
                value="a beautiful fantasy detail, coherent with the original image, natural lighting",
                lines=3,
            )
            # Pixels of context added around the mask bbox in crop mode.
            crop_padding = gr.Slider(
                label="Context around mask",
                minimum=32,
                maximum=384,
                step=32,
                value=128,
            )
            # Collapsed by default; the Generate button below stays outside it.
            with gr.Accordion("Advanced", open=False):
                # The first choice string must match the literal run_inpaint compares against.
                work_mode = gr.Radio(
                    label="Work mode",
                    choices=[
                        "Crop around mask",
                        "Full image no resize",
                    ],
                    value="Crop around mask",
                )
                min_working_window = gr.Slider(
                    label="Minimum working window size",
                    minimum=128,
                    maximum=768,
                    step=64,
                    value=384,
                )
                paste_whole_crop = gr.Checkbox(
                    label="Paste whole generated crop back, diagnostic",
                    value=False,
                )
                steps = gr.Slider(
                    label="Steps",
                    minimum=2,
                    maximum=8,
                    step=1,
                    value=4,
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale / CFG. LCM usually works best around 1.0-2.0",
                    minimum=1.0,
                    maximum=8.0,
                    step=0.1,
                    value=1.5,
                )
                strength = gr.Slider(
                    label="Strength",
                    minimum=0.3,
                    maximum=1.0,
                    step=0.05,
                    value=0.85,
                )
                mask_blur = gr.Slider(
                    label="Mask blur",
                    minimum=0,
                    maximum=16,
                    step=1,
                    value=0,
                )
                # precision=0 rounds the entered seed to an integer.
                seed = gr.Number(
                    label="Seed",
                    value=12345,
                    precision=0,
                )
                random_seed = gr.Checkbox(
                    label="Random seed",
                    value=False,
                )
            button = gr.Button("Generate", variant="primary")
        with gr.Column():
            output = gr.Image(label="Result", type="pil")
            info = gr.Markdown()
    # The inputs list is passed positionally and must follow run_inpaint's parameter order.
    button.click(
        fn=run_inpaint,
        inputs=[
            editor,
            model_label,
            prompt,
            crop_padding,
            work_mode,
            min_working_window,
            steps,
            guidance_scale,
            strength,
            seed,
            random_seed,
            mask_blur,
            paste_whole_crop,
        ],
        outputs=[output, info],
    )
# Enable request queuing (at most 4 waiting jobs), then serve on all
# interfaces at port 7860 with server-side rendering turned off.
demo.queue(max_size=4)
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    ssr_mode=False,
)