Spaces:
Sleeping
Sleeping
File size: 5,744 Bytes
cc8944c d911050 cc8944c 6f3f66a d911050 ba3e3be d911050 6f3f66a d911050 cc8944c d911050 cc8944c 6f3f66a d911050 6f3f66a d911050 cc8944c d911050 cc8944c d911050 cc8944c d911050 6f3f66a d911050 cc8944c 6f3f66a d911050 cc8944c 6f3f66a cc8944c d911050 cc8944c 6f3f66a cc8944c 6f3f66a cc8944c 6f3f66a cc8944c 6f3f66a cc8944c 6f3f66a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import gradio as gr
import numpy as np
import torch
from PIL import Image
import spaces
from omegaconf import OmegaConf
import subprocess
rc = subprocess.call("./setup.sh")
import sys
import os
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lama'))
from lama.saicinpainting.evaluation.refinement import refine_predict
from lama.saicinpainting.training.trainers import load_checkpoint
from lama.saicinpainting.evaluation.utils import move_to_device
# Load the model
def get_inpaint_model():
    """
    Load the LaMa inpainting model and its prediction config.

    Reads ./default.yaml for prediction settings and ./big-lama/config.yaml
    for the training config, then restores the checkpoint in frozen
    (inference-only) mode on the best available device.

    Returns:
        tuple: (model, predict_config) — the model is already moved to the
        selected device, and predict_config.device records that device as a
        string so callers can rebuild the same torch.device.
    """
    predict_config = OmegaConf.load('./default.yaml')
    predict_config.model.path = './big-lama/models/'
    predict_config.refiner.gpu_ids = '0'

    # Prefer the GPU when one is visible; stash the choice in the config.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    predict_config.device = str(device)

    train_config = OmegaConf.load('./big-lama/config.yaml')
    train_config.training_model.predict_only = True
    train_config.visualizer.kind = 'noop'

    checkpoint_path = os.path.join(
        predict_config.model.path, predict_config.model.checkpoint)
    model = load_checkpoint(
        train_config, checkpoint_path, strict=False, map_location=device)
    model.freeze()
    model.to(device)
    return model, predict_config
@spaces.GPU
def inpaint(input_dict, refinement_enabled=False):
    """
    Inpaint the masked region of an image with the LaMa model.

    Args:
        input_dict (dict): Gradio ImageEditor payload with a 'background'
            PIL image and a 'layers' list whose first entry is the
            hand-drawn mask layer.
        refinement_enabled (bool): When True, run the slower iterative
            refinement pass instead of a single forward pass.

    Returns:
        PIL.Image: The inpainted image at the original resolution.
    """
    input_image = np.array(input_dict["background"].convert("RGB")).astype('float32') / 255
    input_mask = pil_to_binary_mask(input_dict['layers'][0])

    # CHW image and 1xHxW mask, as the model expects.
    np_input_image = np.transpose(np.array(input_image), (2, 0, 1))
    np_input_mask = np.array(input_mask)[None, :, :]
    batch = dict(image=np_input_image, mask=np_input_mask)

    inpaint_model, predict_config = get_inpaint_model()
    device = torch.device(predict_config.device)

    # Remember the original size so the padding added below can be stripped
    # from the network output.
    batch['unpad_to_size'] = [torch.tensor([batch['image'].shape[1]]),
                              torch.tensor([batch['image'].shape[2]])]
    batch['image'] = torch.tensor(pad_img_to_modulo(batch['image'], predict_config.dataset.pad_out_to_modulo))[None].to(device)
    batch['mask'] = torch.tensor(pad_img_to_modulo(batch['mask'], predict_config.dataset.pad_out_to_modulo))[None].float().to(device)

    if refinement_enabled:  # truthiness test, not identity with True
        cur_res = refine_predict(batch, inpaint_model, **predict_config.refiner)
        cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
    else:
        with torch.no_grad():
            batch = move_to_device(batch, device)
            batch['mask'] = (batch['mask'] > 0) * 1
            batch = inpaint_model(batch)
            cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
            unpad_to_size = batch.get('unpad_to_size', None)
            if unpad_to_size is not None:
                # unpad_to_size holds 1-element tensors; cast to plain ints
                # so the numpy slice bounds are unambiguous.
                orig_height, orig_width = (int(t) for t in unpad_to_size)
                cur_res = cur_res[:orig_height, :orig_width]

    cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
    return Image.fromarray(cur_res)
def ceil_modulo(x, mod):
    """Round x up to the nearest multiple of mod."""
    quotient, remainder = divmod(x, mod)
    return x if remainder == 0 else (quotient + 1) * mod
def pad_img_to_modulo(img, mod):
    """Symmetrically pad a CHW image so height and width are multiples of mod."""
    channels, height, width = img.shape
    # Ceiling-divide each spatial extent, then scale back up to the target.
    target_height = -(-height // mod) * mod
    target_width = -(-width // mod) * mod
    pad_spec = ((0, 0), (0, target_height - height), (0, target_width - width))
    return np.pad(img, pad_spec, mode='symmetric')
def pil_to_binary_mask(pil_image, threshold=0, max_scale=1):
    """
    Convert a PIL image to a binary mask.

    Args:
        pil_image (PIL.Image): The input PIL image (any mode).
        threshold (int, optional): Gray values strictly above this become
            mask pixels. Defaults to 0.
        max_scale (int, optional): Value assigned to mask pixels; background
            stays 0. Defaults to 1.

    Returns:
        PIL.Image: A grayscale ('L') PIL image representing the binary mask.
    """
    np_image = np.array(pil_image)
    grayscale_image = Image.fromarray(np_image).convert("L")
    binary_mask = np.array(grayscale_image) > threshold
    # Vectorized replacement for the original per-pixel double loop:
    # mask pixels get max_scale, everything else stays 0.
    mask = (binary_mask.astype(np.uint8) * max_scale).astype(np.uint8)
    output_mask = Image.fromarray(mask)
    # Convert mask to grayscale
    return output_mask.convert("L")
# Fix the editor/preview panes to a consistent height in the UI.
css = ".output-image, .input-image, .image-preview {height: 600px !important}"
# Create Gradio interface
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Image Inpainting")
    gr.Markdown("Upload an image and draw a mask to remove unwanted objects.")
    with gr.Row():
        # ImageEditor yields a dict with 'background' and painted 'layers';
        # the brush color is only a visual aid — the mask is binarized later.
        input_image = gr.ImageEditor(type="pil", label='Input image & Mask', interactive=True, height="auto", width="auto", brush=gr.Brush(colors=['#f2e2cd'], default_size=25))
        output_image = gr.Image(type="pil", label="Output Image", height="auto", width="auto")
    with gr.Row():
        refine_checkbox = gr.Checkbox(label="Enable Refinement[SLOWER BUT BETTER]", value=False)
        inpaint_button = gr.Button("Inpaint")
    def inpaint_with_refinement(image, enable_refinement):
        # Thin adapter mapping the checkbox value onto inpaint()'s keyword arg.
        return inpaint(image, refinement_enabled=enable_refinement)
    inpaint_button.click(
        fn=inpaint_with_refinement,
        inputs=[input_image, refine_checkbox],
        outputs=[output_image]
    )
# Launch the interface
if __name__ == "__main__":
    demo.launch()
|