Yogesh Kumar committed on
Commit
2a3b300
·
1 Parent(s): c1fc1b8

Fix for cropped image

Browse files
Files changed (1) hide show
  1. app.py +118 -63
app.py CHANGED
@@ -1,4 +1,7 @@
 
 
1
  import cv2
 
2
  import gradio as gr
3
  import os
4
  from PIL import Image
@@ -16,45 +19,48 @@ os.system("git clone https://github.com/xuebinqin/DIS")
16
  os.system("mv DIS/IS-Net/* .")
17
 
18
  # project imports
19
- from data_loader_cache import normalize, im_reader, im_preprocess
20
- from models import *
21
 
22
- #Helpers
23
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
24
 
25
  # Download official weights
26
  if not os.path.exists("saved_models"):
27
  os.mkdir("saved_models")
28
  os.system("mv isnet.pth saved_models/")
29
-
 
30
  class GOSNormalize(object):
31
  '''
32
  Normalize the Image using torch.transforms
33
  '''
34
- def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
 
35
  self.mean = mean
36
  self.std = std
37
 
38
- def __call__(self,image):
39
- image = normalize(image,self.mean,self.std)
40
  return image
41
 
42
 
43
- transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])
 
 
44
 
45
  def load_image(im_path, hypar):
46
  im = im_reader(im_path)
47
  im, im_shp = im_preprocess(im, hypar["cache_size"])
48
- im = torch.divide(im,255.0)
49
  shape = torch.from_numpy(np.array(im_shp))
50
- return transform(im).unsqueeze(0), shape.unsqueeze(0) # make a batch of image, shape
 
51
 
52
 
53
- def build_model(hypar,device):
54
- net = hypar["model"]#GOSNETINC(3,1)
55
 
56
  # convert to half precision
57
- if(hypar["model_digit"]=="half"):
58
  net.half()
59
  for layer in net.modules():
60
  if isinstance(layer, nn.BatchNorm2d):
@@ -62,96 +68,145 @@ def build_model(hypar,device):
62
 
63
  net.to(device)
64
 
65
- if(hypar["restore_model"]!=""):
66
- net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"], map_location=device))
 
67
  net.to(device)
68
- net.eval()
69
  return net
70
 
71
 
72
- def crop_to_signature(mask):
73
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
 
 
 
 
 
 
74
  if contours:
75
  # Assume the largest contour is the signature
76
  x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
77
- # Add some padding to the bounding box
78
- padding = 32
79
- x, y, w, h = x-padding, y-padding, w+2*padding, h+2*padding
 
 
 
 
80
  # Crop the mask
81
  cropped_mask = mask[y:y+h, x:x+w]
82
- return cropped_mask
83
  else:
84
- return mask # Return the original mask if no contours are found
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
-
87
  def predict(net, inputs_val, shapes_val, hypar, device):
88
  '''
89
  Given an Image, predict the mask
90
  '''
91
  net.eval()
92
 
93
- if(hypar["model_digit"]=="full"):
94
  inputs_val = inputs_val.type(torch.FloatTensor)
95
  else:
96
  inputs_val = inputs_val.type(torch.HalfTensor)
97
 
98
-
99
- inputs_val_v = Variable(inputs_val, requires_grad=False).to(device) # wrap inputs in Variable
100
-
101
- ds_val = net(inputs_val_v)[0] # list of 6 results
102
 
103
- pred_val = ds_val[0][0,:,:,:] # B x 1 x H x W # we want the first one which is the most accurate prediction
 
104
 
105
- ## recover the prediction spatial size to the orignal image size
106
- pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[0][0],shapes_val[0][1]),mode='bilinear'))
 
107
 
108
  ma = torch.max(pred_val)
109
  mi = torch.min(pred_val)
110
- pred_val = (pred_val-mi)/(ma-mi) # max = 1
 
 
 
 
 
 
111
 
112
- if device == 'cuda': torch.cuda.empty_cache()
113
- return (pred_val.detach().cpu().numpy()*255).astype(np.uint8) # it is the mask we need
114
-
115
  # Set Parameters
116
- hypar = {} # paramters for inferencing
117
 
118
 
119
- hypar["model_path"] ="./saved_models" ## load trained weights from this path
120
- hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
121
- hypar["interm_sup"] = False ## indicate if activate intermediate feature supervision
 
122
 
123
- ## choose floating point accuracy --
124
- hypar["model_digit"] = "full" ## indicates "half" or "full" accuracy of float number
 
125
  hypar["seed"] = 0
126
 
127
- hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size
 
128
 
129
- ## data augmentation parameters ---
130
- hypar["input_size"] = [1024, 1024] ## mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
131
- hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation
 
 
132
 
133
  hypar["model"] = ISNetDIS()
134
 
135
- # Build Model
136
  net = build_model(hypar, device)
137
 
138
 
139
  def inference(image):
140
- image_path = image
141
-
142
- image_tensor, orig_size = load_image(image_path, hypar)
143
- mask = predict(net, image_tensor, orig_size, hypar, device)
144
-
145
- pil_mask = Image.fromarray(mask).convert('L')
146
- im_rgb = Image.open(image).convert("RGB")
147
- im_dark = Image.new('RGB', im_rgb.size, (0, 0, 0))
148
-
149
-
150
- im_rgba = im_rgb.copy()
151
- im_rgba.putalpha(pil_mask)
152
- im_dark.putalpha(pil_mask)
153
-
154
- return [im_rgba, pil_mask, im_dark]
 
 
 
 
 
 
 
 
155
 
156
 
157
  title = "Mysign.id - Signature Background removal based on DIS"
@@ -166,4 +221,4 @@ interface = gr.Interface(
166
  description=description,
167
  allow_flagging='never',
168
  cache_examples=False,
169
- ).queue(api_open=True).launch(show_api=True, show_error=True)
 
1
+ from data_loader_cache import normalize, im_reader, im_preprocess
2
+ from models import *
3
  import cv2
4
+ from skimage.restoration import denoise_nl_means, estimate_sigma
5
  import gradio as gr
6
  import os
7
  from PIL import Image
 
19
  os.system("mv DIS/IS-Net/* .")
20
 
21
  # project imports
 
 
22
 
23
+ # Helpers
24
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
25
 
26
  # Download official weights
27
  if not os.path.exists("saved_models"):
28
  os.mkdir("saved_models")
29
  os.system("mv isnet.pth saved_models/")
30
+
31
+
32
class GOSNormalize(object):
    """Normalize an image tensor with per-channel mean and std (torch.transforms style)."""

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        # Delegate to the project's normalize() helper from data_loader_cache.
        return normalize(image, self.mean, self.std)
44
 
45
 
46
# Model-input normalization: zero-center around 0.5 with unit scale.
transform = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])])
48
+
49
 
50
def load_image(im_path, hypar):
    """
    Read an image from disk and prepare it for the network.

    :param im_path: Path to the image file.
    :param hypar: Hyper-parameter dict; hypar["cache_size"] is the target resolution.
    :return: (normalized image batch tensor, original-shape batch tensor).
    """
    raw = im_reader(im_path)
    raw, orig_shape = im_preprocess(raw, hypar["cache_size"])
    raw = torch.divide(raw, 255.0)  # scale pixel values into [0, 1]
    shape_tensor = torch.from_numpy(np.array(orig_shape))
    # unsqueeze(0) turns the single image/shape into a batch of one.
    return transform(raw).unsqueeze(0), shape_tensor.unsqueeze(0)
57
 
58
 
59
def build_model(hypar, device):
    """
    Prepare the network described by ``hypar`` for inference.

    :param hypar: Dict with "model" (the network instance), "model_digit"
        ("half" or "full"), "restore_model" (weight filename or "") and
        "model_path" (weights directory).
    :param device: 'cuda' or 'cpu'.
    :return: The network, moved to ``device`` and set to eval mode.
    """
    net = hypar["model"]  # e.g. ISNetDIS()

    # Half-precision mode: convert weights to fp16 but keep BatchNorm layers
    # in fp32 for numerical stability.
    # NOTE(review): the BatchNorm branch body is hidden in the diff view;
    # reconstructed as the canonical DIS `layer.float()` — confirm upstream.
    if hypar["model_digit"] == "half":
        net.half()
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()

    net.to(device)

    # Restore pretrained weights when a checkpoint name is configured.
    if hypar["restore_model"] != "":
        checkpoint = hypar["model_path"] + "/" + hypar["restore_model"]
        net.load_state_dict(torch.load(checkpoint, map_location=device))
        net.to(device)
    net.eval()
    return net
77
 
78
 
79
def crop_to_signature(mask, padding=32):
    """
    Crop the mask to the bounding box of the largest contour (assumed to be
    the signature), expanded by ``padding`` pixels and clamped to the mask.

    :param mask: Binary/grayscale uint8 mask of the signature.
    :param padding: Extra pixels added on each side of the bounding box.
    :return: (cropped_mask, (x, y, w, h)) with the rectangle expressed in
        the original mask's coordinates.
    """
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                   cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        # Nothing detected: hand back the untouched mask and a full-size box.
        return mask, (0, 0, mask.shape[1], mask.shape[0])

    # The largest contour is assumed to be the signature.
    largest = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest)

    # Grow the box by the padding while keeping it inside the mask bounds.
    x = max(x - padding, 0)
    y = max(y - padding, 0)
    w = min(w + 2 * padding, mask.shape[1] - x)
    h = min(h + 2 * padding, mask.shape[0] - y)

    return mask[y:y + h, x:x + w], (x, y, w, h)
104
+
105
+
106
def smooth_and_denoise(mask):
    """
    Smooth and denoise a grayscale signature mask.

    :param mask: 2-D uint8 mask of the signature.
    :return: Processed uint8 mask (same shape, values in [0, 255]).
    """
    # Gaussian blur softens jagged mask edges.
    smoothed = cv2.GaussianBlur(mask, (5, 5), 0)

    # The mask is single-channel, so there is no channel axis.
    # (channel_axis=None replaces the removed `multichannel` kwarg;
    # multichannel=True was wrong for a 2-D array anyway.)
    sigma_est = np.mean(estimate_sigma(smoothed, channel_axis=None))

    # Non-Local Means denoising returns float64 in [0, 1]; rescale back to
    # uint8 so the result can be used directly as a PIL 'L' mask.
    denoised = denoise_nl_means(smoothed, h=1.15 * sigma_est, fast_mode=True,
                                patch_size=5, patch_distance=3,
                                channel_axis=None)
    return (np.clip(denoised, 0.0, 1.0) * 255).astype(np.uint8)
122
+
123
 
 
124
def predict(net, inputs_val, shapes_val, hypar, device):
    """
    Given a preprocessed image batch, predict the signature mask.

    :param net: Segmentation network; ``net(x)[0]`` is a list of side
        outputs whose first entry is the most accurate prediction.
    :param inputs_val: Input batch tensor (B x C x H x W).
    :param shapes_val: Original image sizes; ``shapes_val[0]`` is (H, W).
    :param hypar: Hyper-parameter dict; hypar["model_digit"] selects float
        precision ("full" or "half").
    :param device: 'cuda' or 'cpu'.
    :return: uint8 numpy mask (H x W) scaled to 0-255.
    """
    net.eval()

    if hypar["model_digit"] == "full":
        inputs_val = inputs_val.type(torch.FloatTensor)
    else:
        inputs_val = inputs_val.type(torch.HalfTensor)

    # Inference only: no autograd graph needed
    # (replaces the deprecated torch.autograd.Variable wrapper).
    with torch.no_grad():
        inputs_val_v = inputs_val.to(device)
        ds_val = net(inputs_val_v)[0]  # list of side-output predictions

    # First side output is the most accurate prediction: B x 1 x H x W.
    pred_val = ds_val[0][0, :, :, :]

    # Recover the prediction spatial size to the original image size.
    # (F.interpolate replaces the deprecated F.upsample.)
    pred_val = torch.squeeze(F.interpolate(
        torch.unsqueeze(pred_val, 0),
        (shapes_val[0][0], shapes_val[0][1]),
        mode='bilinear'))

    # Min-max normalize to [0, 1]; guard against a flat prediction,
    # which would otherwise divide by zero.
    ma = torch.max(pred_val)
    mi = torch.min(pred_val)
    span = ma - mi
    if span > 0:
        pred_val = (pred_val - mi) / span  # max = 1

    if device == 'cuda':
        torch.cuda.empty_cache()
    # This is the mask we need, as a uint8 image.
    return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
155
+
156
 
 
 
 
157
# Inference hyper-parameters.
hypar = {
    "model_path": "./saved_models",  # load trained weights from this path
    "restore_model": "isnet.pth",    # name of the to-be-loaded weights
    "interm_sup": False,             # intermediate feature supervision off
    "model_digit": "full",           # "half" or "full" float accuracy
    "seed": 0,
    # Cached input spatial resolution; can be configured to other sizes.
    "cache_size": [1024, 1024],
    # Data-augmentation parameters: input_size usually equals cache_size
    # (no further resizing); crop_size is normally set smaller than
    # cache_size (e.g. [920, 920]) when random-crop augmentation is wanted.
    "input_size": [1024, 1024],
    "crop_size": [1024, 1024],
    "model": ISNetDIS(),
}

# Build the model once at import time so inference calls reuse it.
net = build_model(hypar, device)
184
 
185
 
186
def inference(image):
    """
    Remove the background from a signature image.

    :param image: Path to the input image (Gradio filepath input).
    :return: [RGBA cutout, grayscale mask, dark-background cutout], all
        cropped to the detected signature region.
    """
    image_tensor, orig_size = load_image(image, hypar)
    mask = predict(net, image_tensor, orig_size, hypar, device)

    # Tighten the mask around the signature, then clean it up.
    cropped_mask, (x, y, w, h) = crop_to_signature(mask)
    processed_mask = smooth_and_denoise(cropped_mask)

    # The denoised mask may come back as floats in [0, 1] (skimage denoisers
    # return float64); PIL's 'L' mode needs uint8 in [0, 255], so coerce
    # before building the alpha channel, otherwise the mask collapses to black.
    mask_arr = np.asarray(processed_mask)
    if mask_arr.dtype != np.uint8:
        mask_arr = (np.clip(mask_arr, 0.0, 1.0) * 255).astype(np.uint8)
    pil_mask = Image.fromarray(mask_arr).convert('L')

    im_rgb = Image.open(image).convert("RGB")

    # Crop the original image to the same region as the mask.
    cropped_image = im_rgb.crop((x, y, x + w, y + h))

    # Dark background of the same size as the cropped region.
    im_dark = Image.new('RGB', (w, h), (0, 0, 0))

    # Apply the mask as the alpha channel of both variants.
    im_rgba = cropped_image.copy()
    im_rgba.putalpha(pil_mask)
    im_dark.putalpha(pil_mask)

    return [im_rgba, pil_mask, im_dark]
210
 
211
 
212
  title = "Mysign.id - Signature Background removal based on DIS"
 
221
  description=description,
222
  allow_flagging='never',
223
  cache_examples=False,
224
+ ).queue(api_open=True).launch(show_api=True, show_error=True)