Spaces:
Sleeping
Sleeping
import gradio as gr
import numpy as np
import torch
from PIL import Image
import spaces
from omegaconf import OmegaConf
import os
import subprocess
import sys
import warnings

# Run the Space's setup script before anything is imported from `lama` below.
# The rest of this file expects a `lama/` source tree and a `./big-lama/`
# checkpoint directory; presumably setup.sh provides them (TODO confirm against
# setup.sh). The script path is relative to the current working directory.
rc = subprocess.call("./setup.sh")
if rc != 0:
    # Don't ignore the failure silently: the `lama` imports below (or the
    # checkpoint load later) will most likely break, and this names the cause.
    warnings.warn(
        f"./setup.sh exited with status {rc}; LaMa sources or checkpoints may be missing.",
        RuntimeWarning,
    )

# Put the `lama/` directory next to this file on sys.path so modules inside it
# can also be imported by their top-level names (presumably needed by lama's own
# internal imports, e.g. `saicinpainting.*` — confirm).
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lama'))
from lama.saicinpainting.evaluation.refinement import refine_predict
from lama.saicinpainting.training.trainers import load_checkpoint
from lama.saicinpainting.evaluation.utils import move_to_device
# Process-wide cache for (model, predict_config). The checkpoint load is
# expensive and depends on no request input, so it is done at most once.
_INPAINT_MODEL_CACHE = None


def _load_inpaint_model():
    """Read the predict/train configs and load the big-lama checkpoint from disk."""
    predict_config = OmegaConf.load('./default.yaml')
    predict_config.model.path = './big-lama/models/'
    # NOTE(review): GPU 0 is requested for the refiner even when CUDA is not
    # available below — confirm refine_predict works on a CPU-only host.
    predict_config.refiner.gpu_ids = '0'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Stored as a string (presumably because OmegaConf only holds primitive
    # values); inpaint() turns it back into a torch.device.
    predict_config.device = str(device)

    train_config_path = './big-lama/config.yaml'
    train_config = OmegaConf.load(train_config_path)
    train_config.training_model.predict_only = True
    train_config.visualizer.kind = 'noop'

    checkpoint_path = os.path.join(predict_config.model.path,
                                   predict_config.model.checkpoint)
    model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location=device)
    model.freeze()
    model.to(device)
    return model, predict_config


# Load the model
def get_inpaint_model():
    """
    Loads and initializes the inpainting model (once per process).

    The first call reads ./default.yaml and ./big-lama/config.yaml and loads the
    checkpoint onto CUDA if available, else CPU. Later calls return the same
    cached objects, so callers must treat the returned config as read-only.

    Returns: Tuple of (model, predict_config)
    """
    global _INPAINT_MODEL_CACHE
    if _INPAINT_MODEL_CACHE is None:
        _INPAINT_MODEL_CACHE = _load_inpaint_model()
    return _INPAINT_MODEL_CACHE
def _combined_layer_mask(layers, shape):
    """Return a (H, W) uint8 mask that is 1 wherever any layer has a drawn pixel.

    Each layer is binarised with pil_to_binary_mask (default threshold and scale,
    i.e. values 0/1). With no layers the mask is all zeros.
    """
    mask = np.zeros(shape, dtype=np.uint8)
    for layer in layers:
        if layer is not None:
            mask = np.maximum(mask, np.array(pil_to_binary_mask(layer)))
    return mask


def inpaint(input_dict, refinement_enabled=False):
    """
    Performs image inpainting on the input image using the provided mask.

    Args:
        input_dict: editor value with 'background' (PIL image to edit) and
            'layers' (list of PIL images; pixels that are non-zero in any layer
            are the region to fill).
        refinement_enabled: if truthy, use LaMa's refine_predict (slower)
            instead of a single forward pass.
    Returns:
        PIL.Image: the inpainted RGB image, cropped back to the input size.
    Raises:
        gr.Error: if no background image was supplied.
    """
    background = input_dict.get("background") if input_dict else None
    if background is None:
        raise gr.Error("Please upload an image first.")

    # float32 in [0, 1], (H, W, 3) -> (3, H, W)
    input_image = np.array(background.convert("RGB")).astype('float32') / 255
    np_input_image = np.transpose(input_image, (2, 0, 1))
    # Add channel dimension for the grayscale mask: (H, W) -> (1, H, W)
    np_input_mask = _combined_layer_mask(input_dict.get("layers") or [], input_image.shape[:2])[None, :, :]
    batch = dict(image=np_input_image, mask=np_input_mask)

    inpaint_model, predict_config = get_inpaint_model()
    device = torch.device(predict_config.device)
    mod = predict_config.dataset.pad_out_to_modulo
    # Remember the unpadded (H, W) before padding; the per-tensor form
    # presumably mirrors LaMa's collated dataset batches — confirm.
    batch['unpad_to_size'] = [torch.tensor([batch['image'].shape[1]]), torch.tensor([batch['image'].shape[2]])]
    batch['image'] = torch.tensor(pad_img_to_modulo(batch['image'], mod))[None].to(device)
    batch['mask'] = torch.tensor(pad_img_to_modulo(batch['mask'], mod))[None].float().to(device)

    if refinement_enabled:
        cur_res = refine_predict(batch, inpaint_model, **predict_config.refiner)
        cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
    else:
        with torch.no_grad():
            batch = move_to_device(batch, device)
            batch['mask'] = (batch['mask'] > 0) * 1
            batch = inpaint_model(batch)
            cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()

    # Crop away the modulo padding.
    unpad_to_size = batch.get('unpad_to_size', None)
    if unpad_to_size is not None:
        orig_height, orig_width = unpad_to_size
        cur_res = cur_res[:orig_height, :orig_width]

    cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
    return Image.fromarray(cur_res)
def ceil_modulo(x, mod):
    """Return the smallest multiple of `mod` that is >= x (x itself if already a multiple)."""
    quotient, remainder = divmod(x, mod)
    if remainder == 0:
        return x
    return (quotient + 1) * mod


def pad_img_to_modulo(img, mod, mode='symmetric'):
    """
    Pad the last two axes (height, width) of `img` at their far end so both
    become multiples of `mod`.

    Args:
        img (np.ndarray): array whose last two axes are (height, width), e.g. a
            (C, H, W) image or a (H, W) mask; leading axes are not padded.
        mod (int): required multiple for height and width.
        mode (str): np.pad mode used to fill the new pixels. Defaults to
            'symmetric' (mirror the border, edge pixel included).
    Returns:
        np.ndarray: padded array with the same number of dimensions as `img`.
    """
    height, width = img.shape[-2:]
    pad_width = [(0, 0)] * (img.ndim - 2) + [
        (0, ceil_modulo(height, mod) - height),
        (0, ceil_modulo(width, mod) - width),
    ]
    return np.pad(img, pad_width, mode=mode)
def pil_to_binary_mask(pil_image, threshold=0, max_scale=1):
    """
    Converts a PIL image to a binary mask.

    Args:
        pil_image (PIL.Image or array-like): the input image. Array-like input is
            wrapped with Image.fromarray first.
        threshold (int, optional): a pixel is "on" when its grayscale value is
            strictly greater than this. Defaults to 0.
        max_scale (int, optional): value stored for "on" pixels (cast to uint8).
            Defaults to 1.
    Returns:
        PIL.Image: a grayscale ("L") image holding max_scale where "on", else 0.
    """
    if not isinstance(pil_image, Image.Image):
        pil_image = Image.fromarray(np.asarray(pil_image))
    # Convert directly so palette ("P") images map by colour, not palette index.
    # NOTE(review): convert("L") ignores alpha, so a fully transparent pixel with
    # non-black RGB would still count as drawn — confirm editor layers clear
    # erased pixels to (0, 0, 0, 0).
    grayscale = np.asarray(pil_image.convert("L"))
    # Vectorized replacement for a per-pixel Python loop.
    mask = (grayscale > threshold).astype(np.uint8)
    mask = (mask * max_scale).astype(np.uint8)
    return Image.fromarray(mask).convert("L")
# Force the editor, output image and previews to a fixed 600px height.
css = ".output-image, .input-image, .image-preview {height: 600px !important}"
# Create Gradio interface
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Image Inpainting")
    gr.Markdown("Upload an image and draw a mask to remove unwanted objects.")
    # Editor (uploaded image + brush strokes drawn as layers) next to the result.
    with gr.Row():
        input_image = gr.ImageEditor(type="pil", label='Input image & Mask', interactive=True, height="auto", width="auto", brush=gr.Brush(colors=['#f2e2cd'], default_size=25))
        output_image = gr.Image(type="pil", label="Output Image", height="auto", width="auto")
    with gr.Row():
        refine_checkbox = gr.Checkbox(label="Enable Refinement[SLOWER BUT BETTER]", value=False)
        inpaint_button = gr.Button("Inpaint")

    # NOTE(review): `spaces` is imported at the top of the file but no handler is
    # decorated with @spaces.GPU — confirm whether this Space is meant to run on
    # ZeroGPU hardware.
    def inpaint_with_refinement(image, enable_refinement):
        """Click handler: forward the editor value and checkbox state to inpaint().

        Kept as a named function: Gradio derives the default API endpoint name
        from the handler's name, so renaming or inlining it changes the API.
        """
        return inpaint(image, refinement_enabled=enable_refinement)

    inpaint_button.click(
        fn=inpaint_with_refinement,
        inputs=[input_image, refine_checkbox],
        outputs=[output_image]
    )
# Launch the interface
if __name__ == "__main__":
    demo.launch()