import asyncio
from typing import cast

import gradio as gr
import numpy as np
from gradio.components.image_editor import EditorValue
from gradio_imageslider import ImageSlider
from PIL import Image
from simple_lama_inpainting import SimpleLama
# Shared inpainting model instance, created once at import time and reused by
# every call to process_image.
simple_lama: SimpleLama = SimpleLama()
def HWC3(x: np.ndarray) -> np.ndarray:
    """Coerce an image array to a 3-channel (H, W, 3) layout.

    * 2-D arrays and single-channel arrays are replicated across 3 channels.
    * 3-channel arrays are returned unchanged (the same object).
    * 4-channel arrays are treated as RGBA and alpha-composited onto white.

    Args:
        x: Image array of shape (H, W) or (H, W, C) with C in {1, 3, 4}.
            The RGBA path divides alpha by 255, so values are assumed to be
            on a 0-255 scale (uint8 images).

    Returns:
        An (H, W, 3) array. The RGBA path returns uint8; the other paths keep
        the input dtype.

    Raises:
        ValueError: If ``x`` is not 2-D/3-D, or has a channel count other than
            1, 3 or 4 (previously such inputs silently returned ``None``).
    """
    if x.ndim == 2:
        x = x[:, :, None]
    if x.ndim != 3:
        raise ValueError(f"HWC3 expects a 2-D or 3-D array, got shape {x.shape}")
    channels = x.shape[2]
    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x, x, x], axis=2)
    if channels == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        # Composite over white so fully transparent pixels become white.
        y = color * alpha + 255.0 * (1.0 - alpha)
        return y.clip(0, 255).astype(np.uint8)
    raise ValueError(f"HWC3 expects 1, 3 or 4 channels, got {channels}")
def process_image(
    image: Image.Image | str | None,
    mask: Image.Image | str | None,
    progress: gr.Progress = gr.Progress(),
    save_path: str | None = "inpainted.png",
) -> Image.Image | None:
    """Inpaint the masked region of ``image`` with the shared ``simple_lama`` model.

    Args:
        image: Source image, or a filesystem path to one.
        mask: Mask image, or a path to one. The caller in this file builds it
            with 255 where the user painted and 0 elsewhere.
        progress: Gradio progress tracker; reset to 0 when preparation starts.
        save_path: File the result is also written to. Defaults to the
            historical ``"inpainted.png"`` in the working directory. Pass
            ``None`` to skip writing, e.g. when several requests may run at
            once and would otherwise overwrite the same file.

    Returns:
        The object returned by ``simple_lama`` (saved via ``.save``), or
        ``None`` when either input is missing.
    """
    progress(0, desc="Preparing inputs...")
    if image is None or mask is None:
        return None
    if isinstance(mask, str):
        mask = Image.open(mask)
    if isinstance(image, str):
        image = Image.open(image)
    # Give the model a 3-channel array: HWC3 replicates grayscale and
    # flattens RGBA onto white.
    image_np = HWC3(np.array(image))
    result = simple_lama(image_np, mask)
    if save_path is not None:
        result.save(save_path)
    return result
def resize_image(img: Image.Image, min_side_length: int = 768) -> Image.Image:
    """Downscale ``img`` so that its shorter side equals ``min_side_length``.

    Images whose shorter side is already at most ``min_side_length`` are
    returned unchanged; this function never upscales. The aspect ratio is
    preserved, and the longer side is truncated to an int.

    Args:
        img: Image to resize.
        min_side_length: Target length, in pixels, for the shorter side.

    Returns:
        ``img`` itself when no downscaling is needed, otherwise a resized copy.
    """
    width, height = img.width, img.height
    # Compare the *shorter* side. Requiring both sides to fit would send an
    # image such as 500x1000 through the resize path and upscale it to
    # 768x1536.
    if min(width, height) <= min_side_length:
        return img
    if width < height:
        new_height = int(min_side_length * height / width)
        return img.resize((min_side_length, new_height))
    new_width = int(min_side_length * width / height)
    return img.resize((new_width, min_side_length))
async def process(
    image_and_mask: EditorValue | None,
    progress: gr.Progress = gr.Progress(),
) -> tuple[Image.Image, Image.Image] | None:
    """Validate the editor value, build a binary mask and run inpainting.

    Args:
        image_and_mask: Value of the ``gr.ImageMask`` component. This function
            reads its ``"background"`` image and the alpha channel of
            ``"layers"[0]``, both expected as numpy arrays (they are cast to
            ``np.ndarray``).
        progress: Gradio progress tracker, forwarded to ``process_image``.

    Returns:
        ``(resized_input, inpainted)`` for the before/after slider, or
        ``None`` when the input is incomplete or inference returned nothing.
        In those cases the user is told why via ``gr.Info``.
    """
    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None

    image_np = image_and_mask["background"]
    # A missing or all-zero background is treated as "nothing uploaded".
    if image_np is None or not np.any(image_np):
        gr.Info("Please upload an image")
        return None
    image_np = cast(np.ndarray, image_np)

    layers = image_and_mask["layers"]
    if not layers:
        gr.Info("Please mark the areas you want to remove")
        return None
    layer = cast(np.ndarray, layers[0])
    # Any pixel the brush touched has non-zero alpha; turn that into a 0/255 mask.
    mask_np = np.where(layer[:, :, 3] == 0, 0, 255).astype(np.uint8)
    if not mask_np.any():
        gr.Info("Please mark the areas you want to remove")
        return None

    mask = resize_image(Image.fromarray(mask_np))
    image = resize_image(Image.fromarray(image_np))

    # process_image is synchronous; run it in a worker thread so this async
    # handler does not block Gradio's event loop while the model runs.
    output = await asyncio.to_thread(process_image, image, mask, progress)
    if output is None:
        gr.Info("Processing failed")
        return None
    # gr.Progress takes a completion fraction in [0, 1], not a percentage.
    progress(1.0, desc="Processing completed")
    return image, output
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Upload-only editor with a single fixed black brush for marking
            # the region to remove.
            image_and_mask = gr.ImageMask(
                label="Upload Image and Draw Mask",
                layers=False,
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                height="full",
                width="full",
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                transforms=[],
            )
        with gr.Column():
            # Shows the (resized) input and the inpainted output side by side.
            image_slider = ImageSlider(
                label="Result",
                interactive=False,
            )

    # NOTE(review): the original indentation was lost, so this button is
    # placed at the Blocks level (full width below the row). Confirm the
    # intended layout.
    # A ClearButton: clicking it also clears the previous result in the slider.
    process_btn = gr.ClearButton(
        value="Run",
        variant="primary",
        size="lg",
        components=[image_slider],
    )

    # With inputs=[] Gradio calls these callbacks with no arguments, so they
    # must take none. A `lambda _:` raises TypeError and the button never
    # toggles.
    process_btn.click(
        fn=lambda: gr.update(interactive=False, value="Processing..."),
        inputs=[],
        outputs=[process_btn],
        api_name=False,
    ).then(
        fn=process,
        inputs=[
            image_and_mask,
        ],
        outputs=[image_slider],
        api_name=False,
    ).then(
        fn=lambda: gr.update(interactive=True, value="Run"),
        inputs=[],
        outputs=[process_btn],
        api_name=False,
    )
if __name__ == "__main__":
    # Start the app without a public share link and without the API page.
    launch_options = {"show_api": False, "share": False, "debug": False}
    demo.launch(**launch_options)