File size: 11,803 Bytes
6a61d48
2a3b300
6a61d48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4857a1c
 
6a61d48
4857a1c
6a61d48
 
 
 
 
9a7c58d
4857a1c
6a61d48
 
 
 
4857a1c
6a61d48
 
 
4857a1c
 
6a61d48
 
 
4857a1c
6a61d48
c32414c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a61d48
c32414c
6a61d48
c32414c
 
 
 
 
 
6a61d48
 
4857a1c
 
6a61d48
 
4857a1c
6a61d48
 
 
 
 
 
 
4857a1c
 
6a61d48
4857a1c
6a61d48
 
eb476cc
42d6656
 
 
0ba9b7b
42d6656
 
b30b3e7
 
 
0ba9b7b
42d6656
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ba9b7b
 
42d6656
0ba9b7b
31ad6fa
42d6656
0ba9b7b
 
162b5c5
0ba9b7b
42d6656
0ba9b7b
42d6656
 
0ba9b7b
42d6656
 
91be27b
2a3b300
91be27b
c09a6b8
91be27b
2a3b300
91be27b
 
2a3b300
91be27b
 
 
c09a6b8
 
6171288
91be27b
 
c09a6b8
91be27b
eb476cc
c09a6b8
 
 
2a3b300
c09a6b8
 
 
 
6171288
2a3b300
c09a6b8
 
 
 
 
 
 
91be27b
 
c09a6b8
91be27b
 
42d6656
2a3b300
 
 
 
 
 
 
cf689d2
 
 
 
2a3b300
 
 
 
601d327
2a3b300
 
 
601d327
2a3b300
 
cf689d2
31ad6fa
 
93228fa
31ad6fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93228fa
 
 
 
 
7e6cd1e
 
93228fa
 
 
7e6cd1e
93228fa
 
31ad6fa
4857a1c
6a61d48
 
 
 
 
 
4857a1c
6a61d48
 
 
 
4857a1c
 
 
 
6a61d48
4857a1c
6a61d48
4857a1c
 
6a61d48
 
 
4857a1c
6a61d48
4857a1c
 
 
6a61d48
4857a1c
6a61d48
 
4857a1c
 
 
6a61d48
4857a1c
 
6a61d48
 
4857a1c
6a61d48
4857a1c
 
 
6a61d48
 
 
4857a1c
6a61d48
 
6f249a9
6a61d48
2a3b300
c666413
 
802830a
91be27b
802830a
 
aa6fbd2
7c3ff1d
91be27b
ce9792b
91be27b
c666413
e10b294
6171288
e10b294
802830a
e10b294
ce9792b
162b5c5
c666413
a7f9f53
6a61d48
f525870
 
6a61d48
 
 
 
fd093ba
f525870
6a61d48
 
 
 
4857a1c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
import cv2
from skimage.restoration import denoise_nl_means, estimate_sigma
import gradio as gr
import os
from PIL import Image
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import transforms
import torch.nn.functional as F
import gdown
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")  # silence every warning process-wide (this also hides deprecation notices)

# Fetch the DIS reference implementation and move its IS-Net sources into the
# current working directory so the project modules imported below can be found.
# Both commands go through the shell on every start-up and their exit codes are
# ignored; presumably a repeated clone into an existing DIS/ folder just fails
# harmlessly. NOTE(review): this requires network access at import time.
os.system("git clone https://github.com/xuebinqin/DIS")
os.system("mv DIS/IS-Net/* .")

# project imports (presumably only resolvable after the move above)
# NOTE(review): `models` presumably provides ISNetDIS and the `nn` name used in
# build_model via the star import — verify against the DIS sources.
from data_loader_cache import normalize, im_reader, im_preprocess 
from models import *

#Helpers
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Put the IS-Net weights where build_model() loads them from
# (hypar["model_path"] + "/" + hypar["restore_model"]). Nothing is downloaded
# here: isnet.pth is presumably shipped next to this script. The move is done
# whenever the target file is still missing, so a pre-existing (e.g. empty)
# saved_models/ directory no longer leaves the weights stranded.
os.makedirs("saved_models", exist_ok=True)
_weights_target = os.path.join("saved_models", "isnet.pth")
if os.path.isfile("isnet.pth") and not os.path.exists(_weights_target):
    os.replace("isnet.pth", _weights_target)
    
class GOSNormalize(object):
    '''
    Callable transform that normalizes an image tensor channel-wise using the
    project's ``normalize`` helper (imported from data_loader_cache).

    The defaults are the ImageNet per-channel mean / std.
    '''
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # Store private list copies: instances must never alias a shared default
        # or a caller-owned sequence, otherwise mutating one instance's
        # statistics would silently change every later GOSNormalize().
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, image):
        # Delegate to the project's normalize(image, mean, std).
        return normalize(image, self.mean, self.std)


# Input pipeline applied by load_image after pixels are divided by 255:
# subtract 0.5 per channel and divide by 1.0 (assuming the project's normalize
# computes (x - mean) / std), i.e. shift [0, 1] inputs to [-0.5, 0.5].
transform = transforms.Compose(
    [GOSNormalize(mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])]
)

def load_image(image, hypar):
    """
    Load and preprocess an image.
    :param image: The image to load. This can be either a file path or a PIL.Image object.
    :param hypar: Hyperparameters for preprocessing; ``hypar["cache_size"]`` is
        passed to ``im_preprocess``.
    :return: A tuple of the preprocessed image tensor (with a leading batch
        dimension) and its original shape as a (1, 2) tensor.
    :raises TypeError: If ``image`` is neither a str path nor a PIL image.
    """
    if isinstance(image, str):
        # If it's a file path, read the image from disk
        im = im_reader(image)
    elif isinstance(image, Image.Image):
        # PIL modes such as "RGBA", "LA", "P" or "CMYK" would not yield the
        # 3-channel array that the 3-channel normalisation expects.
        im = np.array(image.convert("RGB"))
    else:
        raise TypeError("Unsupported image type")

    # Images read from disk (e.g. transparent PNG signatures) may carry a 4th,
    # alpha channel; drop it so the tensor matches the 3-channel mean/std.
    im = np.asarray(im)
    if im.ndim == 3 and im.shape[2] == 4:
        im = np.ascontiguousarray(im[:, :, :3])

    # Preprocess the image
    im, im_shp = im_preprocess(im, hypar["cache_size"])
    im = torch.divide(im, 255.0)
    shape = torch.from_numpy(np.array(im_shp))

    # Normalize and add batch dimension
    im = transform(im).unsqueeze(0)
    shape = shape.unsqueeze(0)  # Add batch dimension to shape

    return im, shape


def build_model(hypar, device):
    """
    Prepare the network stored in ``hypar["model"]`` for inference.

    :param hypar: Parameter dict. Reads "model" (the module instance),
        "model_digit" ("half" converts weights to float16 except BatchNorm2d
        layers), "restore_model" (weights file name, "" to skip loading) and,
        when loading, "model_path" (directory holding that file).
    :param device: Target device, e.g. 'cuda' or 'cpu'.
    :return: The same module instance, moved to ``device`` and in eval mode.
    """
    model = hypar["model"]

    if hypar["model_digit"] == "half":
        model.half()
        # BatchNorm2d layers are put back to float32 (presumably for
        # numerical stability of the running statistics).
        for bn_layer in (m for m in model.modules() if isinstance(m, nn.BatchNorm2d)):
            bn_layer.float()

    model.to(device)

    weights_name = hypar["restore_model"]
    if weights_name != "":
        checkpoint_path = "/".join((hypar["model_path"], weights_name))
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        model.to(device)

    model.eval()
    return model


def crop_signature(original_image_path, mask, padding=32, smooth=False):
    """
    Build a black signature "stamp" from the predicted mask, cropped to the signature.

    The mask is binarised at 127, the union bounding box of all external
    contours is grown by ``padding`` pixels on every side (clamped to the image
    bounds), and the cropped binary mask becomes the alpha channel of an
    all-black RGBA image.

    :param original_image_path: Path of the image the mask was predicted for;
        used for its size and as the fallback return value.
    :param mask: Single-channel 8-bit mask (0-255). Presumably the same size as
        the image (predict() resizes to the original shape) — TODO confirm for
        other callers.
    :param padding: Pixels added around the signature bounding box.
    :param smooth: If True, pass the cropped mask through ``smooth_edges``
        before using it as alpha. Defaults to False: the hard thresholded mask
        is used as-is.
    :return: RGBA PIL image of the cropped signature, or the original image
        converted to RGB when the mask contains no foreground.
    """
    _, binary_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Load fully and release the file handle right away.
    with Image.open(original_image_path) as src:
        original_image = src.convert("RGB")

    if not contours:
        return original_image

    # Union bounding box of all external contours.
    min_x, min_y = original_image.width, original_image.height
    max_x = max_y = 0
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        min_x, min_y = min(min_x, x), min(min_y, y)
        max_x, max_y = max(max_x, x + w), max(max_y, y + h)

    # Padded box, clamped to the image bounds (end coordinates are exclusive).
    x0 = max(min_x - padding, 0)
    y0 = max(min_y - padding, 0)
    x1 = min(max_x + padding, original_image.width)
    y1 = min(max_y + padding, original_image.height)

    alpha = binary_mask[y0:y1, x0:x1]
    if smooth:
        alpha = smooth_edges(alpha)

    # Size the stamp from the actual crop so it always matches the alpha
    # channel, even if the mask is smaller than the image.
    crop_h, crop_w = alpha.shape[:2]
    mask_image = Image.new('RGBA', (crop_w, crop_h), (0, 0, 0))
    mask_image.putalpha(Image.fromarray(alpha).convert('L'))
    return mask_image

# An earlier crop_signature variant, which cropped the RGB pixels of the
# original photo instead of building a black RGBA stamp from the mask, used to
# be kept here as an unused module-level string literal. It had no runtime
# effect and has been removed; recover it from version control if the
# photo-crop behaviour is needed again.

def smooth_and_denoise(mask):
    """
    Apply Gaussian smoothing followed by non-local-means denoising to a mask.

    :param mask: The signature mask. If it has more than two dimensions only
        the first channel (``mask[..., 0]``) is used. Integer masks (e.g. uint8
        0-255) are rescaled to floats in [0, 1] after blurring.
    :return: Denoised float mask, nominally in [0, 1] for integer input
        (multiply by 255 to get back to 8-bit).
    """
    mask = np.asarray(mask)
    if mask.ndim > 2:
        # cv2 expects a contiguous single-channel buffer, not a strided view.
        mask = np.ascontiguousarray(mask[..., 0])

    # Apply Gaussian Blurring for smoothing
    smoothed_mask = cv2.GaussianBlur(mask, (5, 5), 0)

    # estimate_sigma measures noise in the units of the array it is given,
    # whereas denoise_nl_means (preserve_range=False by default, per the
    # scikit-image docs) rescales integer input to [0, 1] before filtering.
    # Converting to a [0, 1] float up front keeps the noise estimate and the
    # filter strength ``h`` on the same scale.
    smoothed_mask = smoothed_mask.astype(np.float64)
    if np.issubdtype(mask.dtype, np.integer):
        smoothed_mask /= np.iinfo(mask.dtype).max

    sigma_est = float(np.mean(estimate_sigma(smoothed_mask, channel_axis=None)))

    # Apply Non-Local Means Denoising
    return denoise_nl_means(smoothed_mask, h=1.15 * sigma_est, fast_mode=True,
                            patch_size=5, patch_distance=3, channel_axis=None)


def smooth_edges(mask):
    """
    Smooth the edges of a binary mask.

    Steps: morphological close, then open (3x3 ellipse, 2 iterations each) to
    fill small holes and remove specks; one dilation pass to thicken strokes
    slightly; a LANCZOS down-scale by 2 and back up for anti-aliasing; and a
    final bilateral filter.

    :param mask: 2-D mask. It is converted with ``astype(np.uint8)``, so it is
        expected to hold 0-255 values (floats in [0, 1] would become 0/1).
    :return: uint8 mask with the same height and width as the input.
    """
    mask = mask.astype(np.uint8)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    closing = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)
    opening = cv2.morphologyEx(closing, cv2.MORPH_OPEN, kernel, iterations=2)
    dilated = cv2.dilate(opening, kernel, iterations=1)

    pil_mask = Image.fromarray(dilated)

    # Pillow < 9.1 has no Image.Resampling; fall back to the module constant.
    lanczos = getattr(Image, "Resampling", Image).LANCZOS
    # Clamp to 1 px: PIL rejects a zero width/height, which halving a 1-px-wide
    # or 1-px-high mask would otherwise request.
    half_size = (max(pil_mask.width // 2, 1), max(pil_mask.height // 2, 1))
    pil_mask_smooth = pil_mask.resize(half_size, lanczos).resize(pil_mask.size, lanczos)

    smoothed_mask = np.array(pil_mask_smooth)
    return cv2.bilateralFilter(smoothed_mask, d=9, sigmaColor=75, sigmaSpace=75)

    
def predict(net,  inputs_val, shapes_val, hypar, device):
    '''
    Given an Image, predict the mask.

    :param net: Segmentation network. ``net(x)[0][0]`` must be a B x 1 x H x W
        tensor (the first side output is used).
    :param inputs_val: Preprocessed input batch; only the first item is used.
    :param shapes_val: Tensor whose first row holds the target (height, width).
    :param hypar: Parameter dict; "model_digit" == "full" selects float32,
        anything else float16 inputs.
    :param device: Device the network lives on, e.g. 'cuda' or 'cpu'.
    :return: uint8 mask of shape (height, width), min-max scaled to 0-255.
        A constant prediction (max == min) yields an all-zero mask.
    '''
    net.eval()

    if hypar["model_digit"] == "full":
        inputs_val = inputs_val.type(torch.FloatTensor)
    else:
        inputs_val = inputs_val.type(torch.HalfTensor)
    inputs_val = inputs_val.to(device)

    out_h, out_w = int(shapes_val[0][0]), int(shapes_val[0][1])

    # Inference only: skip building an autograd graph (saves memory).
    with torch.no_grad():
        ds_val = net(inputs_val)[0]  # presumably the list of side outputs
        pred_val = ds_val[0][0, :, :, :]  # first side output, first batch item: 1 x H x W

        # Recover the prediction's spatial size to the original image size.
        pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val, 0), (out_h, out_w),
                                               mode='bilinear', align_corners=False))

        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        if ma > mi:
            pred_val = (pred_val - mi) / (ma - mi)  # max = 1
        else:
            # Avoid 0/0 -> NaN (and an undefined NaN -> uint8 cast).
            pred_val = torch.zeros_like(pred_val)

    if device == 'cuda':
        torch.cuda.empty_cache()
    return (pred_val.cpu().numpy() * 255).astype(np.uint8)  # it is the mask we need
    
# Inference parameters consumed by load_image, build_model and predict.
hypar = {
    # Weights are loaded from model_path + "/" + restore_model.
    "model_path": "./saved_models",
    "restore_model": "isnet.pth",
    # Intermediate feature supervision switch (not read by the code shown here).
    "interm_sup": False,
    # Floating-point precision: "half" or "full".
    "model_digit": "full",
    "seed": 0,
    # Cached input spatial resolution; passed to im_preprocess by load_image.
    "cache_size": [1024, 1024],
    # Data-augmentation settings (not read by the code shown here). input_size
    # is usually equal to cache_size, i.e. no further resizing; crop_size is
    # usually smaller than cache_size (e.g. [920, 920]) for random cropping.
    "input_size": [1024, 1024],
    "crop_size": [1024, 1024],
    "model": ISNetDIS(),
}

# Build the network once at import time so every request reuses it.
net = build_model(hypar, device)


def inference(image):
    """
    Run signature background removal on one uploaded image.

    :param image: File path of the uploaded image (the Gradio input below uses
        type='filepath').
    :return: A list of three PIL images:
        [cropped signature stamp (RGBA, or the RGB original if no foreground
        was found), original image with the predicted alpha, black image with
        the predicted alpha].
    """
    image_tensor, orig_size = load_image(image, hypar)
    original_mask = predict(net, image_tensor, orig_size, hypar, device)
    pil_original_mask = Image.fromarray(original_mask).convert('L')

    # The outputs use the raw predicted mask. The NL-means smoothed mask
    # (smooth_and_denoise) was previously computed here but never used, which
    # cost a full denoising pass per request for nothing.

    with Image.open(image) as src:
        im_rgb = src.convert("RGB")

    cropped_signature_image = crop_signature(image, original_mask, 64)

    im_rgba = im_rgb.copy()
    im_rgba.putalpha(pil_original_mask)

    im_dark = Image.new('RGB', im_rgb.size, (0, 0, 0))
    im_dark.putalpha(pil_original_mask)

    return [cropped_signature_image, im_rgba, im_dark]

title = "Mysign.id - Signature Background removal based on DIS"
description = "ML Model based on ECCV2022/dis-background-removal specifically made for removing background from signatures."

# One file-path input; three image outputs matching inference()'s return list:
# the cropped signature stamp, the photo with the predicted alpha, and a black
# image with the predicted alpha.
interface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type='filepath'),
    outputs=["image", "image", "image"],
    examples=[['example-1.jpg'], ['example-2.jpg']],
    title=title,
    description=description,
    allow_flagging='never',
    cache_examples=False,
)

# Keep `interface` bound to the Interface object itself (chaining .launch()
# onto the constructor would bind it to launch()'s return value instead), then
# enable the request queue with the API exposed and start serving.
interface.queue(api_open=True).launch(show_api=True, show_error=True)