Spaces:
Running on Zero
| #!/usr/bin/env python | |
| import os | |
| import gradio as gr | |
| import numpy as np | |
| import PIL.Image | |
| import spaces | |
| import torch | |
| from transformers import VitMatteForImageMatting, VitMatteImageProcessor | |
DESCRIPTION = "# [ViTMatte](https://github.com/hustvl/ViTMatte)"

# Use the first CUDA device when one is visible; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Both settings can be overridden via environment variables.
MAX_IMAGE_SIZE = int(os.environ.get("MAX_IMAGE_SIZE", "1500"))
MODEL_ID = os.environ.get("MODEL_ID", "hustvl/vitmatte-small-distinctions-646")

# The preprocessor and model are loaded once at import time and shared by `run`.
processor = VitMatteImageProcessor.from_pretrained(MODEL_ID)
model = VitMatteForImageMatting.from_pretrained(MODEL_ID)
model = model.to(device)
def check_image_size(image: PIL.Image.Image | None) -> None:
    """Reject images whose longer side exceeds ``MAX_IMAGE_SIZE`` pixels.

    The ``change`` event this is bound to may deliver ``None`` (e.g. when the
    input image is cleared); a missing image is accepted silently.

    Args:
        image: The uploaded image, or ``None``.

    Raises:
        gr.Error: If the image's width or height exceeds ``MAX_IMAGE_SIZE``.
    """
    if image is None:
        return
    if max(image.size) > MAX_IMAGE_SIZE:
        raise gr.Error(f"Image size is too large. Max image size is {MAX_IMAGE_SIZE} pixels.")
def binarize_mask(mask: np.ndarray) -> np.ndarray:
    """Threshold a sketch mask to 0/1.

    Pixels with value >= 128 become 1 and all others become 0. The result has
    the same dtype as ``mask``.

    The input is left untouched. Callers such as ``update_trimap`` pass a
    slice (view) of a larger mask array, so in-place writes would silently
    modify the caller's data.

    Args:
        mask: Array of mask intensities.

    Returns:
        A new array of the same shape and dtype containing only 0 and 1.
    """
    return (mask >= 128).astype(mask.dtype)
def update_trimap(foreground_mask: dict[str, np.ndarray], unknown_mask: dict[str, np.ndarray]) -> np.ndarray:
    """Build a trimap from the two drawn sketch masks.

    Args:
        foreground_mask: Sketch value whose ``"mask"`` entry marks foreground strokes.
        unknown_mask: Sketch value whose ``"mask"`` entry marks unknown (border) strokes.

    Returns:
        A 2-D array holding 255 where foreground was drawn, 128 where unknown
        was drawn, and 0 elsewhere. Foreground wins where the strokes overlap.

    Raises:
        gr.Error: If either canvas has no value yet (no image loaded).
    """
    # The sketch components deliver None until an image has been loaded into them.
    if foreground_mask is None or unknown_mask is None:
        raise gr.Error("Load an image and draw the masks before setting the trimap.")
    # Only the first channel is read; presumably all mask channels carry the same strokes — TODO confirm.
    foreground = binarize_mask(foreground_mask["mask"][:, :, 0])
    unknown = binarize_mask(unknown_mask["mask"][:, :, 0])
    trimap = np.zeros_like(foreground)
    trimap[unknown > 0] = 128
    # Written last so foreground overrides unknown where both were drawn.
    trimap[foreground > 0] = 255
    return trimap
def adjust_background_image(background_image: PIL.Image.Image, target_size: tuple[int, int]) -> PIL.Image.Image:
    """Scale ``background_image`` to cover ``target_size``, then center-crop it.

    The image is resized uniformly (aspect ratio preserved) so that it is at
    least as large as the target in both dimensions. The central region of
    exactly ``target_size`` is then cropped out.

    Args:
        background_image: Image to fit.
        target_size: ``(width, height)`` of the desired output.

    Returns:
        An image of size ``target_size``.
    """
    target_w, target_h = target_size
    bg_w, bg_h = background_image.size
    scale = max(target_w / bg_w, target_h / bg_h)
    # Truncating ``bg * scale`` with int() can land one pixel short of the target
    # (e.g. 49 * (1 / 49) evaluates to 0.9999999999999999). The crop below would
    # then reach past the image edge, or the resize would produce a zero size.
    # Round instead, and never go below the target.
    new_bg_w = max(target_w, round(bg_w * scale))
    new_bg_h = max(target_h, round(bg_h * scale))
    background_image = background_image.resize((new_bg_w, new_bg_h))
    # Center the crop window inside the resized image.
    left = (new_bg_w - target_w) // 2
    top = (new_bg_h - target_h) // 2
    right = left + target_w
    bottom = top + target_h
    return background_image.crop((left, top, right, bottom))
def replace_background(
    image: PIL.Image.Image, alpha: np.ndarray, background_image: PIL.Image.Image | None
) -> PIL.Image.Image | None:
    """Composite ``image`` over ``background_image`` using ``alpha`` as the matte.

    The background is converted to RGB, then scaled and center-cropped to the
    size of ``image`` before blending.

    Args:
        image: RGB foreground image.
        alpha: 2-D matte indexed ``[row, col]``; assumed to match ``image``'s
            height and width with values in [0, 1] — TODO confirm at callers.
        background_image: Replacement background, or ``None``.

    Returns:
        The composited image, or ``None`` when no background was given.

    Raises:
        gr.Error: If ``image`` is not in RGB mode.
    """
    if background_image is None:
        return None
    if image.mode != "RGB":
        raise gr.Error("Image must be RGB.")
    background_image = adjust_background_image(background_image.convert("RGB"), image.size)
    fg = np.array(image).astype(float) / 255
    bg = np.array(background_image).astype(float) / 255
    # Add a trailing channel axis so the 2-D matte broadcasts over RGB.
    a = alpha[:, :, None]
    result = fg * a + bg * (1 - a)
    # Return a PIL image, as annotated (previously a raw ndarray leaked out).
    return PIL.Image.fromarray((result * 255).astype(np.uint8))
@spaces.GPU
@torch.inference_mode()
def run(
    image: PIL.Image.Image,
    trimap: PIL.Image.Image,
    apply_background_replacement: bool,
    background_image: PIL.Image.Image | None,
) -> tuple[np.ndarray, PIL.Image.Image, PIL.Image.Image | None]:
    """Predict an alpha matte for ``image`` guided by ``trimap``.

    Runs under ``torch.inference_mode()``. The model's parameters require
    grad, so without it the output tensor would track gradients and
    ``.numpy()`` would raise. ``spaces.GPU`` requests a GPU for the duration
    of the call on ZeroGPU hardware.

    Args:
        image: RGB input image.
        trimap: Grayscale ("L") trimap with the same size as ``image``.
        apply_background_replacement: Whether to also composite the subject
            over ``background_image``.
        background_image: Replacement background; only used when
            ``apply_background_replacement`` is true.

    Returns:
        ``(alpha, foreground, replaced)``:

        - ``alpha`` is the predicted matte cropped to the input size.
        - ``foreground`` is the subject blended over white.
        - ``replaced`` is the background-replacement result, or ``None`` when
          it was not requested or no background was given.

    Raises:
        gr.Error: If an input is missing, the sizes differ, the image is too
            large, or the image/trimap modes are not RGB/L.
    """
    if image is None or trimap is None:
        raise gr.Error("Both an input image and a trimap are required.")
    if image.size != trimap.size:
        raise gr.Error("Image and trimap must have the same size.")
    check_image_size(image)
    if image.mode != "RGB":
        raise gr.Error("Image must be RGB.")
    if trimap.mode != "L":
        raise gr.Error("Trimap must be grayscale.")

    pixel_values = processor(images=image, trimaps=trimap, return_tensors="pt").to(device).pixel_values
    out = model(pixel_values=pixel_values)
    alpha = out.alphas[0, 0].to("cpu").numpy()
    # The prediction may be larger than the input (presumably padding added by
    # the processor — TODO confirm); crop it back to the original height/width.
    w, h = image.size
    alpha = alpha[:h, :w]

    # Blend the subject over a white background: fg * a + 1 * (1 - a).
    foreground = np.array(image).astype(float) / 255 * alpha[:, :, None] + (1 - alpha[:, :, None])
    foreground = PIL.Image.fromarray((foreground * 255).astype(np.uint8))

    if apply_background_replacement:
        res_bg_replacement = replace_background(image, alpha, background_image)
    else:
        res_bg_replacement = None
    return alpha, foreground, res_bg_replacement
# Gradio UI: input column (image, trimap upload or drawing, optional background
# replacement) and output column (alpha, foreground, replaced background).
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown("This is the demo for [ViTMatte](https://github.com/hustvl/ViTMatte), an image matting application. You can matte any subject in a given image.")
    gr.Markdown("If you wish to replace background of the image, simply select the checkbox and drag and drop your background image.")
    gr.Markdown("You can draw your own foreground mask and unknown (border) mask using the canvas.")
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Row():
        with gr.Column():
            with gr.Box():
                image = gr.Image(label="Input image", type="pil", height=500)
                with gr.Tabs():
                    with gr.Tab(label="Trimap"):
                        trimap = gr.Image(label="Trimap", type="pil", image_mode="L", height=500)
                    with gr.Tab(label="Draw trimap"):
                        load_image_button = gr.Button("Load image")
                        foreground_mask = gr.Image(
                            label="Foreground",
                            tool="sketch",
                            type="numpy",
                            brush_color="green",
                            mask_opacity=0.7,
                            height=500,
                        )
                        unknown_mask = gr.Image(
                            label="Unknown",
                            tool="sketch",
                            type="numpy",
                            brush_color="green",
                            mask_opacity=0.7,
                            height=500,
                        )
                        set_trimap_button = gr.Button("Set trimap")
                # Checkbox has no `checked` parameter; its initial state is `value`.
                apply_background_replacement = gr.Checkbox(label="Apply background replacement", value=False)
                background_image = gr.Image(label="Background image", type="pil", height=500, visible=False)
                run_button = gr.Button("Run")
        with gr.Column():
            with gr.Box():
                out_alpha = gr.Image(label="Alpha", height=500)
                out_foreground = gr.Image(label="Foreground", height=500)
                out_background_replacement = gr.Image(label="Background replacement", height=500, visible=False)

    # Order must match the parameters / return tuple of `run`.
    inputs = [
        image,
        trimap,
        apply_background_replacement,
        background_image,
    ]
    outputs = [
        out_alpha,
        out_foreground,
        out_background_replacement,
    ]
    gr.Examples(
        examples=[
            ["assets/retriever_rgb.png", "assets/retriever_trimap.png", False, None],
            ["assets/bulb_rgb.png", "assets/bulb_trimap.png", True, "assets/new_bg.jpg"],
        ],
        inputs=inputs,
        outputs=outputs,
        fn=run,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )

    image.change(
        fn=check_image_size,
        inputs=image,
        queue=False,
        api_name=False,
    )
    # Copy the input image into both sketch canvases so the user can draw on it.
    load_image_button.click(
        fn=lambda image: (image, image),
        inputs=image,
        outputs=[foreground_mask, unknown_mask],
        queue=False,
        api_name=False,
    )
    set_trimap_button.click(
        fn=update_trimap,
        inputs=[foreground_mask, unknown_mask],
        outputs=trimap,
        queue=False,
        api_name=False,
    )
    # Show/hide the background input and its output together. `gr.update` is the
    # property-update form supported by Gradio 3.x (which this UI targets via
    # gr.Box / tool="sketch") as well as 4.x.
    apply_background_replacement.change(
        fn=lambda checked: (gr.update(visible=checked), gr.update(visible=checked)),
        inputs=apply_background_replacement,
        outputs=[background_image, out_background_replacement],
        queue=False,
        api_name=False,
    )
    run_button.click(
        fn=run,
        inputs=inputs,
        outputs=outputs,
        api_name="run",
    )
if __name__ == "__main__":
    # Enable the request queue (holding at most 20 pending events), then start the server.
    demo.queue(max_size=20)
    demo.launch()