# %cd rem  (notebook magic: change into the project directory before running)
import time

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from scipy.ndimage import label, find_objects

from diffusers import ControlNetModel, DEISMultistepScheduler
from transformers import DPTFeatureExtractor, DPTForDepthEstimation

# Project-local modules: the ControlNet inpainting pipeline (provides
# StableDiffusionControlNetInpaintPipeline) and the first-pass inpaint helper.
from src.pipeline_stable_diffusion_controlnet_inpaint import *
from src.core import process_inpaint

# Depth estimator (DPT) that produces the depth map used as the ControlNet
# conditioning signal, plus the depth ControlNet trained for SD 2.1.
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
controlnet = ControlNetModel.from_pretrained(
    "thibaud/controlnet-sd21-depth-diffusers", torch_dtype=torch.float16
)

# Stable Diffusion 2 inpainting pipeline wired to the depth ControlNet.
# The pipeline class comes from the project-local module imported above.
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)

# DEIS is a fast multistep scheduler, which keeps quality reasonable at the
# low step count used in predict() below.
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)

pipe.to("cuda")

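# Optional (assumption, not part of the original script): on GPUs with limited
# VRAM, diffusers pipelines support attention slicing to trade speed for memory.
# pipe.enable_attention_slicing()
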
def resize_image(image, target_size):
    # Resize so the longest side equals target_size, preserving aspect ratio.
    width, height = image.size
    aspect_ratio = float(width) / float(height)
    if width > height:
        new_width = target_size
        new_height = int(target_size / aspect_ratio)
    else:
        new_width = int(target_size * aspect_ratio)
        new_height = target_size
    return image.resize((new_width, new_height), Image.BICUBIC)

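# Example: a 1536x1024 input comes back as 768x512 (longest side clamped to
# 768). Note the result is not snapped to a multiple of 8; the diffusers
# pipeline is assumed to handle any rounding it needs internally.
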
def get_depth_map(image, target_size):
    # Run DPT depth estimation, then resize the prediction to target_size,
    # normalize it to [0, 1], and replicate it to three channels so it can be
    # used as a ControlNet conditioning image.
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=target_size,  # (height, width)
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)

    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image

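# Note: target_size uses torch's (height, width) convention, the reverse of
# PIL's Image.size (width, height). Typical call, assuming a PIL image `img`:
# depth = get_depth_map(img, (img.height, img.width))
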
def add_split_line(mask_image, line_thickness):
    # Draw a black line through the middle of every white region in the mask,
    # splitting each connected component in two along its longer axis.
    if mask_image.mode != 'L':
        mask_image = mask_image.convert('L')

    mask_array = np.array(mask_image)

    # Label connected components of white (255) pixels.
    labeled_array, num_features = label(mask_array == 255)

    draw = ImageDraw.Draw(mask_image)

    for i in range(1, num_features + 1):
        # Bounding box of component i: row extent (vertical) and column
        # extent (horizontal). find_objects returns (row_slice, col_slice).
        slice_rows, slice_cols = find_objects(labeled_array == i)[0]
        top, bottom = slice_rows.start, slice_rows.stop
        left, right = slice_cols.start, slice_cols.stop

        if (right - left) > (bottom - top):
            # Wider than tall: split with a vertical line.
            center_x = (left + right) // 2
            draw.line([(center_x, top), (center_x, bottom)], fill=0, width=line_thickness)
        else:
            # Taller than wide (or square): split with a horizontal line.
            center_y = (top + bottom) // 2
            draw.line([(left, center_y), (right, center_y)], fill=0, width=line_thickness)

    return mask_image

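# Illustration (hypothetical mask, not part of the app flow):
# m = Image.new("L", (100, 100), 0)
# ImageDraw.Draw(m).rectangle([10, 30, 90, 70], fill=255)  # one wide blob
# split = add_split_line(m, line_thickness=10)             # now two blobs
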
def predict(input_dict):
    start_time = time.time()

    # The Gradio sketch tool supplies both the base image and the drawn mask.
    image = input_dict["image"].convert("RGB")
    input_image = input_dict["mask"].convert("RGBA")
    image = resize_image(image, 768)
    input_image = resize_image(input_image, 768)
    # Diffusion-pass mask: each drawn region split by a thin black line.
    # add_split_line works on an internal 'L' copy, so input_image keeps the
    # original RGBA sketch for the color checks below.
    mask_holes = add_split_line(input_image, 10)

    # Convert to arrays. image is already RGB, so the alpha check below is
    # defensive.
    image_np = np.array(image)
    drawing_np = np.array(input_image)

    if image_np.shape[2] == 4:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)

    # Build an RGBA mask from the sketch: pixels painted magenta (255, 0, 255)
    # become transparent (alpha 0, i.e. "inpaint here"); untouched black
    # pixels become opaque (alpha 255, i.e. "keep").
    background = np.where(
        (drawing_np[:, :, 0] == 0) &
        (drawing_np[:, :, 1] == 0) &
        (drawing_np[:, :, 2] == 0)
    )
    drawing = np.where(
        (drawing_np[:, :, 0] == 255) &
        (drawing_np[:, :, 1] == 0) &
        (drawing_np[:, :, 2] == 255)
    )
    mask_np = np.zeros_like(drawing_np)
    mask_np[background] = [0, 0, 0, 255]
    mask_np[drawing] = [0, 0, 0, 0]

    # First pass: fill the masked holes with the project's inpaint helper.
    inpainted_image_np = process_inpaint(image_np, mask_np)
    inpainted_image = Image.fromarray(inpainted_image_np)

    # Pixels with non-zero alpha in the mask are the ones to keep untouched.
    unmasked_region = mask_np[:, :, 3] != 0

    # Copy of the first-pass result that is fed to the diffusion pipeline.
    blended_image_np = np.array(inpainted_image_np)

    # PIL reports size as (width, height); get_depth_map expects
    # (height, width), so flip the tuple.
    blended_image_size = inpainted_image.size
    flipped_size = (blended_image_size[1], blended_image_size[0])
    depth_image = get_depth_map(inpainted_image, flipped_size)

    # Second pass: depth-guided diffusion inpainting over the split-line mask.
    generator = torch.manual_seed(0)
    output = pipe(
        prompt="",
        num_inference_steps=8,
        generator=generator,
        image=blended_image_np,
        control_image=depth_image,
        controlnet_conditioning_scale=0.9,
        mask_image=mask_holes,
    ).images[0]

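    # Assumption: 8 steps with the DEIS scheduler is a speed/quality
    # trade-off; raising num_inference_steps is the usual first knob if the
    # fill looks rough. The fixed seed (0) keeps runs reproducible.
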
    output_np = np.array(output)

    # Paste the first-pass pixels back everywhere outside the mask, so the
    # diffusion pass only contributes inside the holes.
    if output_np.shape[:2] == inpainted_image_np.shape[:2]:
        output_np[unmasked_region] = inpainted_image_np[unmasked_region]
    else:
        print("Dimension mismatch: cannot apply unmasked_region")

    final_output = Image.fromarray(output_np)

    end_time = time.time()
    inference_time = end_time - start_time
    inference_time_str = f"Inference Time: {inference_time:.2f} seconds"

    return final_output, inference_time_str

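# Calling predict outside Gradio (the sketch tool's payload is a dict with
# "image" and "mask" keys, both PIL images):
# result_img, timing = predict({"image": base_img, "mask": mask_img})
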
image_blocks = gr.Blocks()

with image_blocks as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                source='upload',
                tool='sketch',
                elem_id="input_image_upload",
                type="pil",
                label="Upload & Draw on Image",
            )
            btn = gr.Button("Remove Object")
        with gr.Column():
            result = gr.Image(label="Result")
            inference_time_label = gr.Label()
    btn.click(fn=predict, inputs=[input_image], outputs=[result, inference_time_label])

demo.launch(debug=True)
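# Note: launch(debug=True) blocks the process and prints tracebacks to the
# console; pass share=True as well if a temporary public link is needed.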