File size: 5,744 Bytes
cc8944c
 
d911050
cc8944c
6f3f66a
d911050
 
ba3e3be
 
 
d911050
 
 
 
 
 
6f3f66a
d911050
cc8944c
 
d911050
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc8944c
6f3f66a
 
d911050
 
 
 
 
6f3f66a
d911050
cc8944c
d911050
 
 
cc8944c
d911050
 
cc8944c
d911050
 
 
 
6f3f66a
 
 
 
 
 
 
 
 
 
 
 
 
 
d911050
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc8944c
6f3f66a
d911050
 
 
 
 
 
 
 
 
 
cc8944c
 
 
 
 
 
 
 
6f3f66a
cc8944c
d911050
 
cc8944c
6f3f66a
 
cc8944c
6f3f66a
cc8944c
 
 
 
6f3f66a
 
 
 
 
 
 
 
 
cc8944c
6f3f66a
 
 
 
 
cc8944c
 
 
6f3f66a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import functools

import gradio as gr
import numpy as np
import torch
from omegaconf import OmegaConf
from PIL import Image
import spaces

import subprocess
# One-time environment setup (presumably fetches weights / installs extras) —
# must run before the lama imports below. NOTE(review): the return code is
# captured but never checked; failures are silently ignored — confirm intended.
rc = subprocess.call("./setup.sh")

import sys 
import os
# Make the vendored `lama` package importable regardless of the CWD the app
# was launched from.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lama'))

from lama.saicinpainting.evaluation.refinement import refine_predict
from lama.saicinpainting.training.trainers import load_checkpoint
from lama.saicinpainting.evaluation.utils import move_to_device


# Load the model
@functools.lru_cache(maxsize=1)
def get_inpaint_model():
    """
    Loads and initializes the LaMa inpainting model.

    The result is cached (``lru_cache``), so the checkpoint is read from
    disk only on the first call; every later inference request reuses the
    same (model, config) pair instead of reloading per call.

    Returns:
        tuple: (model, predict_config). ``predict_config.device`` holds the
        device string ('cuda' or 'cpu') the model was moved to.
    """
    predict_config = OmegaConf.load('./default.yaml')
    predict_config.model.path = './big-lama/models/'
    predict_config.refiner.gpu_ids = '0'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Store the device as a string so the OmegaConf config stays serializable.
    predict_config.device = str(device)
    train_config_path = './big-lama/config.yaml'

    train_config = OmegaConf.load(train_config_path)
    train_config.training_model.predict_only = True
    train_config.visualizer.kind = 'noop'

    checkpoint_path = os.path.join(predict_config.model.path,
                                   predict_config.model.checkpoint)

    model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location=device)
    model.freeze()
    model.to(device)
    return model, predict_config

@spaces.GPU
def inpaint(input_dict, refinement_enabled=False):
    """
    Performs image inpainting on the input image using the provided mask.

    Args:
        input_dict (dict): Gradio ImageEditor payload with 'background'
            (the PIL photo) and 'layers' (painted mask layers; layer 0 is used).
        refinement_enabled (bool, optional): When truthy, run the slower LaMa
            refinement pass for higher quality. Defaults to False.

    Returns:
        PIL.Image: The inpainted output image.
    """
    # Normalize the photo to float32 RGB in [0, 1]; binarize the painted mask.
    input_image = np.array(input_dict["background"].convert("RGB")).astype('float32') / 255
    input_mask = pil_to_binary_mask(input_dict['layers'][0])

    np_input_image = np.transpose(input_image, (2, 0, 1))  # HWC -> CHW
    np_input_mask = np.array(input_mask)[None, :, :]  # Add channel dimension for grayscale images
    batch = dict(image=np_input_image, mask=np_input_mask)

    inpaint_model, predict_config = get_inpaint_model()
    device = torch.device(predict_config.device)

    # Remember the original H/W so the modulo padding added below can be cropped off.
    batch['unpad_to_size'] = [torch.tensor([batch['image'].shape[1]]), torch.tensor([batch['image'].shape[2]])]
    # LaMa expects spatial dims divisible by the configured modulo.
    batch['image'] = torch.tensor(pad_img_to_modulo(batch['image'], predict_config.dataset.pad_out_to_modulo))[None].to(device)
    batch['mask'] = torch.tensor(pad_img_to_modulo(batch['mask'], predict_config.dataset.pad_out_to_modulo))[None].float().to(device)

    if refinement_enabled:
        # refine_predict handles cropping back to the original size internally.
        cur_res = refine_predict(batch, inpaint_model, **predict_config.refiner)
        cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
    else:
        with torch.no_grad():
            batch = move_to_device(batch, device)
            batch['mask'] = (batch['mask'] > 0) * 1
            batch = inpaint_model(batch)
            cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
            # Crop away the modulo padding to restore the original dimensions.
            unpad_to_size = batch.get('unpad_to_size', None)
            if unpad_to_size is not None:
                orig_height, orig_width = unpad_to_size
                cur_res = cur_res[:orig_height, :orig_width]

    cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
    output_image = Image.fromarray(cur_res)

    return output_image

def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod``."""
    # Ceiling division expressed through floor division on the negated value.
    return -(-x // mod) * mod

def pad_img_to_modulo(img, mod):
    """Pad a CHW array on the bottom/right so H and W become multiples of ``mod``.

    Padding uses symmetric (edge-mirroring) mode so the model sees natural
    image content rather than zeros in the padded border.
    """
    _, height, width = img.shape
    pad_bottom = ceil_modulo(height, mod) - height
    pad_right = ceil_modulo(width, mod) - width
    padding = ((0, 0), (0, pad_bottom), (0, pad_right))
    return np.pad(img, padding, mode='symmetric')

def pil_to_binary_mask(pil_image, threshold=0, max_scale=1):
    """
    Converts a PIL image to a binary mask.

    Args:
        pil_image (PIL.Image): The input PIL image.
        threshold (int, optional): Grayscale values strictly greater than this
            are treated as foreground. Defaults to 0.
        max_scale (int, optional): Value written into foreground pixels
            (background stays 0). Defaults to 1.

    Returns:
        PIL.Image: A grayscale ("L") PIL image representing the binary mask.
    """
    np_image = np.array(pil_image)
    grayscale_image = Image.fromarray(np_image).convert("L")
    binary_mask = np.array(grayscale_image) > threshold
    # Vectorized: replaces the original per-pixel Python double loop
    # (O(H*W) interpreter iterations) with a single numpy pass.
    mask = (binary_mask.astype(np.uint8) * max_scale).astype(np.uint8)
    output_mask = Image.fromarray(mask)
    # Convert mask to grayscale
    return output_mask.convert("L")

# Pin the image panes to a fixed height so input and output align side by side.
css = ".output-image, .input-image, .image-preview {height: 600px !important}"

# Create Gradio interface
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Image Inpainting")
    gr.Markdown("Upload an image and draw a mask to remove unwanted objects.")
    
    with gr.Row():
        # ImageEditor yields a dict with 'background' (the photo) and 'layers'
        # (the painted mask strokes) — the shape inpaint() expects.
        input_image = gr.ImageEditor(type="pil", label='Input image & Mask', interactive=True, height="auto", width="auto", brush=gr.Brush(colors=['#f2e2cd'], default_size=25))
        output_image = gr.Image(type="pil", label="Output Image", height="auto", width="auto")
    
    with gr.Row():
        refine_checkbox = gr.Checkbox(label="Enable Refinement[SLOWER BUT BETTER]", value=False)
        inpaint_button = gr.Button("Inpaint")

    def inpaint_with_refinement(image, enable_refinement):
        """Adapter: forwards the checkbox value as inpaint()'s keyword flag."""
        return inpaint(image, refinement_enabled=enable_refinement)

    inpaint_button.click(
        fn=inpaint_with_refinement,
        inputs=[input_image, refine_checkbox],
        outputs=[output_image]
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch()