# mysign.id / app.py
# Signature background removal Space by itsyogesh.
# Last change: 162b5c5 "Remove smoothning" (note: smoothing helpers below are
# retained but their results are no longer used downstream).
import cv2
from skimage.restoration import denoise_nl_means, estimate_sigma
import gradio as gr
import os
from PIL import Image
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import transforms
import torch.nn.functional as F
import gdown
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
# Fetch the DIS (Dichotomous Image Segmentation) reference implementation and
# flatten its IS-Net sources into the working directory so the project imports
# below resolve. NOTE(review): os.system results are unchecked; a failed clone
# surfaces only as the import error that follows.
os.system("git clone https://github.com/xuebinqin/DIS")
os.system("mv DIS/IS-Net/* .")
# project imports (provided by the cloned IS-Net sources above)
from data_loader_cache import normalize, im_reader, im_preprocess
from models import *
# Helpers
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Download official weights
# NOTE(review): nothing here actually downloads isnet.pth; it is assumed to
# already sit in the working directory (e.g. shipped with the Space) -- confirm.
if not os.path.exists("saved_models"):
    os.mkdir("saved_models")
    os.system("mv isnet.pth saved_models/")
class GOSNormalize(object):
    """Callable transform that channel-normalizes an image tensor.

    Delegates to the project-level ``normalize`` helper using the mean and
    standard deviation supplied at construction time.
    """

    def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        # Apply the project's normalize() with the stored statistics.
        return normalize(image, self.mean, self.std)
# Single-step preprocessing pipeline: mean 0.5 / std 1.0 shifts pixel values
# to roughly [-0.5, 0.5], matching what the IS-Net weights expect.
transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])
def load_image(image, hypar):
    """
    Load and preprocess an image for IS-Net inference.

    :param image: Either a file path (str) or a PIL.Image object.
    :param hypar: Hyperparameter dict; only "cache_size" is used here.
    :return: Tuple of (normalized image tensor with batch dim,
             original-shape tensor with batch dim).
    """
    if isinstance(image, str):
        # File path: delegate reading to the project helper.
        raw = im_reader(image)
    elif isinstance(image, Image.Image):
        # Already-loaded PIL image: convert to a NumPy array.
        raw = np.array(image)
    else:
        raise TypeError("Unsupported image type")

    # Resize/cache to the model's working resolution, then scale to [0, 1].
    cached, original_shape = im_preprocess(raw, hypar["cache_size"])
    cached = torch.divide(cached, 255.0)
    shape_tensor = torch.from_numpy(np.array(original_shape))

    # Normalize and prepend a batch dimension to both outputs.
    return transform(cached).unsqueeze(0), shape_tensor.unsqueeze(0)
def build_model(hypar,device):
    """Prepare the network configured in ``hypar`` for inference.

    Optionally casts the model to half precision (keeping BatchNorm layers in
    float32 for numerical stability), restores trained weights when a
    checkpoint name is configured, moves the model to ``device``, and switches
    it to eval mode.

    :param hypar: Hyperparameter dict ("model", "model_digit",
        "restore_model", "model_path").
    :param device: 'cuda' or 'cpu'.
    :return: The prepared network (same object as hypar["model"]).
    """
    net = hypar["model"]

    # Half-precision mode: cast weights, but keep BatchNorm in float32.
    if hypar["model_digit"] == "half":
        net.half()
        for module in net.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.float()

    net.to(device)

    # Restore trained weights when a checkpoint filename is configured.
    if hypar["restore_model"] != "":
        checkpoint_path = hypar["model_path"] + "/" + hypar["restore_model"]
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))
        net.to(device)

    net.eval()
    return net
def crop_signature(original_image_path, mask, padding=32):
    """
    Build an RGBA signature cut-out from the predicted mask.

    The mask is binarized, the combined bounding box of all detected contours
    is computed (padded and clamped to the image bounds), and the result is a
    black RGBA image whose alpha channel is the cropped binary mask.

    :param original_image_path: File path of the original image (used for its
        dimensions, and returned unchanged when no contours are found).
    :param mask: Grayscale (uint8) mask predicted by the network.
    :param padding: Pixels of margin added around the detected bounding box.
    :return: RGBA PIL image of the cropped mask, or the original RGB image
        when the mask contains no contours.
    """
    # Binarize the mask so contour detection sees crisp edges.
    _, binary_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

    # External contours suffice: only the overall bounding box is needed.
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Opened mainly for its dimensions (and as the no-contour fallback).
    original_image = Image.open(original_image_path).convert("RGB")

    if contours:
        # Combined bounding box over every contour.
        min_x, min_y = original_image.width, original_image.height
        max_x = max_y = 0
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            min_x, min_y = min(min_x, x), min(min_y, y)
            max_x, max_y = max(max_x, x + w), max(max_y, y + h)

        # Pad the box, clamping to the image bounds.
        x_padded = max(min_x - padding, 0)
        y_padded = max(min_y - padding, 0)
        w_padded = min(max_x + padding, original_image.width) - x_padded
        h_padded = min(max_y + padding, original_image.height) - y_padded

        # Crop the binary mask to the padded box.
        cropped_mask = binary_mask[y_padded:y_padded+h_padded, x_padded:x_padded+w_padded]

        # FIX: removed the dead `smooth_edges(cropped_mask)` call -- its
        # result was computed and then discarded (leftover from the
        # "Remove smoothning" change), wasting a full morphology/resize/
        # bilateral-filter pass per request.

        # Black RGBA canvas with the cropped mask as the alpha channel.
        mask_image = Image.new('RGBA', (w_padded, h_padded), (0, 0, 0))
        pil_mask = Image.fromarray(cropped_mask).convert('L')
        mask_image.putalpha(pil_mask)
        return mask_image

    # No contours found: fall back to the untouched original image.
    return original_image
# NOTE(review): the triple-quoted block below is the retained pre-"Remove
# smoothning" variant of crop_signature (it cropped the original image rather
# than the mask). Kept for reference only; it is dead code.
'''
def crop_signature(original_image_path, mask, padding=32):
"""
Crop the signature from the original image using the provided mask.
:param original_image_path: The file path of the original image.
:param mask: The binary mask of the signature.
:param padding: Padding to add around the bounding box of the signature.
:return: Cropped image containing the signature.
"""
# Convert the mask to a binary image
_, binary_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
# Find contours from the binary mask
contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# Open the original image
original_image = Image.open(original_image_path).convert("RGB")
# If contours are found, proceed to crop
if contours:
# Find the combined bounding box of all contours
min_x, min_y = original_image.width, original_image.height
max_x = max_y = 0
for contour in contours:
x, y, w, h = cv2.boundingRect(contour)
min_x, min_y = min(min_x, x), min(min_y, y)
max_x, max_y = max(max_x, x + w), max(max_y, y + h)
# Add padding to the bounding box
x_padded = max(min_x - padding, 0)
y_padded = max(min_y - padding, 0)
w_padded = min(max_x + padding, original_image.width) - x_padded
h_padded = min(max_y + padding, original_image.height) - y_padded
# Crop the original image using the combined bounding box with padding
cropped_image = original_image.crop((x_padded, y_padded, x_padded + w_padded, y_padded + h_padded))
return cropped_image
# If no contours are found, return the original image
return original_image
'''
def smooth_and_denoise(mask):
    """
    Smooth and denoise a signature mask.

    Gaussian-blurs the mask, estimates the residual noise level, then runs
    non-local-means denoising with a strength proportional to that estimate.

    :param mask: Grayscale mask array (a trailing channel axis, if present,
        is dropped).
    :return: Denoised mask as a float array (as produced by denoise_nl_means).
    """
    # Collapse any trailing channel axis so we operate on a single 2-D plane.
    if mask.ndim > 2:
        mask = mask[..., 0]

    # Light Gaussian smoothing before denoising.
    blurred = cv2.GaussianBlur(mask, (5, 5), 0)

    # The estimated noise sigma drives the NL-means filtering strength.
    noise_sigma = np.mean(estimate_sigma(blurred, channel_axis=None))

    return denoise_nl_means(
        blurred,
        h=1.15 * noise_sigma,
        fast_mode=True,
        patch_size=5,
        patch_distance=3,
        channel_axis=None,
    )
def smooth_edges(mask):
    """
    Smooth the edges of a binary mask.

    Pipeline: morphological close + open (fill small holes, drop speckle),
    a slight dilation to thicken strokes, an anti-aliasing down/up Lanczos
    resize, and a final edge-preserving bilateral filter.

    :param mask: Binary mask of the signature.
    :return: uint8 mask with smoothed edges.
    """
    work = mask.astype(np.uint8)

    # 3x3 elliptical structuring element used for all morphology below.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

    # Close small holes in the mask, then open to remove noise.
    work = cv2.morphologyEx(work, cv2.MORPH_CLOSE, ellipse, iterations=2)
    work = cv2.morphologyEx(work, cv2.MORPH_OPEN, ellipse, iterations=2)

    # Slightly thicken the signature strokes.
    work = cv2.dilate(work, ellipse, iterations=1)

    # Anti-aliasing trick: shrink to half size, then scale back up.
    pil_mask = Image.fromarray(work)
    half_size = (pil_mask.width // 2, pil_mask.height // 2)
    resampled = pil_mask.resize(half_size, Image.Resampling.LANCZOS).resize(
        pil_mask.size, Image.Resampling.LANCZOS
    )

    # Edge-preserving final smoothing pass.
    return cv2.bilateralFilter(np.array(resampled), d=9, sigmaColor=75, sigmaSpace=75)
def predict(net, inputs_val, shapes_val, hypar, device):
    """
    Run the network on a preprocessed image and return the predicted mask.

    :param net: The (already built) segmentation network.
    :param inputs_val: Preprocessed image tensor, shape (1, C, H, W).
    :param shapes_val: Per-image original (H, W) sizes; shapes_val[0] is used.
    :param hypar: Hyperparameters; "model_digit" selects full/half precision.
    :param device: 'cuda' or 'cpu'.
    :return: uint8 mask (values 0-255) at the original image resolution.
    """
    net.eval()

    # Match the input tensor precision to the model precision.
    if hypar["model_digit"] == "full":
        inputs_val = inputs_val.type(torch.FloatTensor)
    else:
        inputs_val = inputs_val.type(torch.HalfTensor)

    # FIX: torch.autograd.Variable is deprecated; a plain tensor inside a
    # no_grad() block is the modern equivalent of requires_grad=False.
    with torch.no_grad():
        inputs_val_v = inputs_val.to(device)
        ds_val = net(inputs_val_v)[0]  # list of 6 side outputs
        # The first side output is the most accurate prediction: B x 1 x H x W.
        pred_val = ds_val[0][0, :, :, :]

        # Recover the prediction to the original image size.
        # FIX: F.upsample is deprecated in favor of F.interpolate.
        pred_val = torch.squeeze(
            F.interpolate(
                torch.unsqueeze(pred_val, 0),
                (shapes_val[0][0], shapes_val[0][1]),
                mode='bilinear',
            )
        )

        # Min-max normalize to [0, 1].
        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        # FIX: guard against a constant prediction (ma == mi), which
        # previously produced NaNs via division by zero.
        if ma > mi:
            pred_val = (pred_val - mi) / (ma - mi)
        else:
            pred_val = torch.zeros_like(pred_val)

    if device == 'cuda':
        torch.cuda.empty_cache()

    # Scale to 0-255 and hand back as a NumPy uint8 mask.
    return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
# Set Parameters
hypar = {} # parameters for inferencing
hypar["model_path"] ="./saved_models" ## load trained weights from this path
hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
hypar["interm_sup"] = False ## whether to activate intermediate feature supervision
## choose floating point accuracy --
hypar["model_digit"] = "full" ## "half" or "full" floating-point accuracy
hypar["seed"] = 0
hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution; can be configured to a different size
## data augmentation parameters ---
hypar["input_size"] = [1024, 1024] ## model input spatial size; usually equal to hypar["cache_size"], meaning images are not further resized
hypar["crop_size"] = [1024, 1024] ## random crop size from the input; usually smaller than hypar["cache_size"], e.g. [920,920], for data augmentation
hypar["model"] = ISNetDIS()
# Build Model
net = build_model(hypar, device)
def inference(image):
    """
    Run signature extraction on an uploaded image.

    :param image: File path of the input image (Gradio Image type='filepath').
    :return: List of three PIL images: [cropped signature mask cut-out,
             original image with the mask as alpha, black image with the
             mask as alpha].
    """
    image_path = image
    image_tensor, orig_size = load_image(image_path, hypar)
    original_mask = predict(net, image_tensor, orig_size, hypar, device)

    # FIX: dropped the dead smooth_and_denoise()/pil_processed_mask
    # computation -- its result was never used (leftover from the "Remove
    # smoothning" change), yet it ran an expensive NL-means pass per request.
    pil_original_mask = Image.fromarray(original_mask).convert('L')

    im_rgb = Image.open(image).convert("RGB")
    im_dark = Image.new('RGB', im_rgb.size, (0, 0, 0))

    # Padded RGBA cut-out built from the predicted mask.
    cropped_signature_image = crop_signature(image_path, original_mask, 64)

    # Apply the predicted mask as the alpha channel of both variants.
    im_rgba = im_rgb.copy()
    im_rgba.putalpha(pil_original_mask)
    im_dark.putalpha(pil_original_mask)

    return [cropped_signature_image, im_rgba, im_dark]
# UI copy shown on the Gradio page.
title = "Mysign.id - Signature Background removal based on DIS"
description = "ML Model based on ECCV2022/dis-background-removal specifically made for removing background from signatures."
# Build and immediately launch the Gradio interface. type='filepath' hands
# inference() a path on disk, which is what load_image()/crop_signature()
# expect. Three image outputs mirror inference()'s returned list.
interface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type='filepath'),
    outputs=["image", "image", "image"],
    examples=[['example-1.jpg'], ['example-2.jpg']],
    title=title,
    description=description,
    allow_flagging='never',
    cache_examples=False,
).queue(api_open=True).launch(show_api=True, show_error=True)