| from hashlib import sha1 |
| from pathlib import Path |
|
|
| import cv2 |
| import gradio as gr |
| import numpy as np |
| from PIL import Image |
| import PIL |
| import torch |
| from torchvision import transforms |
| import torch.nn.functional as F |
|
|
|
|
def estimate_foreground_ml(image, alpha, return_background=False):
    """Split an image into foreground/background composites via an alpha matte.

    Despite the name, this is a plain alpha multiply, not the multilevel
    foreground estimation algorithm from the matting literature.

    Parameters
    ----------
    image : np.ndarray
        RGB image of shape (H, W, 3).
    alpha : np.ndarray
        Alpha matte broadcastable against *image*, values in [0, 1].
    return_background : bool, optional
        When True, also return the background composite.

    Returns
    -------
    np.ndarray or tuple of np.ndarray
        The foreground, or ``(foreground, background)`` when
        *return_background* is True.
    """
    fg = alpha * image

    if not return_background:
        return fg

    inv_alpha = 1 - alpha
    # NOTE(review): the 255 term assumes a 0-255 pixel scale, yet the caller
    # in this file passes a [0, 1] float image — confirm intended scale if
    # return_background is ever enabled.
    bg = inv_alpha * image + (1 - inv_alpha) * 255
    return fg, bg
|
|
|
|
def load_entire_model(taskname):
    """Load the TorchScript network(s) that implement a task.

    "mask" uses a single salient-object-detection net; "matting" chains a
    trimap generator followed by the matting net. Any other task name
    yields an empty list.

    Parameters
    ----------
    taskname : str
        Either "mask" or "matting".

    Returns
    -------
    list
        Loaded models in eval mode, in the order they are applied.
    """
    paths_by_task = {
        "mask": ["./models/sod.pt"],
        "matting": ["./models/trimap.pt", "./models/matting.pt"],
    }

    models = []
    for model_path in paths_by_task.get(taskname, []):
        net = torch.jit.load(Path(model_path))
        net.eval()
        models.append(net)
    return models
|
|
|
|
# Algorithms exposed in the UI dropdown; each name is understood by
# load_entire_model().
model_names = ["matting", "mask"]

# Registry keyed by algorithm name, initialized empty (no model loaded yet).
model_dict = dict.fromkeys(model_names)

# Memo of the most recent inference so a repeated run on the same image with
# the same algorithm can reuse the cached alpha matte.
last_result = {
    "cache_key": None,
    "algorithm": None,
}
|
|
|
|
def image_matting(
    image: PIL.Image.Image,
    result_type: str,
    bg_color: str,
    algorithm: str,
    morph_op: str,
    morph_op_factor: float,
) -> np.ndarray:
    """Run salient-object matting on *image* and composite the result.

    Parameters
    ----------
    image : PIL.Image.Image
        Input RGB image.
    result_type : str
        "Remove BG", "Replace BG", or anything else to return the matte.
    bg_color : str
        Hex color "#rrggbb" used when replacing the background.
    algorithm : str
        "mask" (single SOD net) or "matting" (trimap + matting nets).
    morph_op : str
        "Dilate", "Erode", or "None" — post-processing on the matte.
    morph_op_factor : float
        Kernel size and iteration count for the morphological op.

    Returns
    -------
    np.ndarray
        Float image in [0, 1]: foreground, composite, or alpha matte.
    """
    image_np = np.ascontiguousarray(image)
    width, height = image_np.shape[1], image_np.shape[0]
    # Cache keyed on image bytes + algorithm: re-running with only different
    # post-processing settings skips the expensive network passes.
    cache_key = sha1(image_np).hexdigest()
    if cache_key == last_result["cache_key"] and algorithm == last_result["algorithm"]:
        alpha = last_result["alpha"]
    else:
        model = load_entire_model(algorithm)
        transform = transforms.Compose([
            transforms.Resize((798, 798)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        if algorithm == "mask":
            input_tensor = transform(image).unsqueeze(0)
            with torch.no_grad():
                alpha = model[0](input_tensor).float()
            alpha = F.interpolate(alpha, [height, width], mode="bilinear")
            alpha = np.array(alpha * 255.).astype(np.uint8)[0][0]
            alpha = np.stack((alpha, alpha, alpha), axis=2)
        else:
            # The matting net takes unnormalized pixels, hence a second
            # transform without Normalize.
            transform2 = transforms.Compose([
                transforms.Resize((800, 800)),
                transforms.ToTensor(),
            ])

            # Stage 1: predict a trimap at the original resolution.
            input_tensor = transform(image).unsqueeze(0)
            with torch.no_grad():
                output = model[0](input_tensor).float()
            output = F.interpolate(output, [height, width], mode="bilinear")

            trimap = np.array(output[0][0])

            # Crop to the detected object, padded by 5% on each side.
            ratio = 0.05
            site = np.where(trimap > 0)
            try:
                bbox = [np.min(site[1]), np.min(site[0]), np.max(site[1]), np.max(site[0])]
            except ValueError:
                # Empty trimap (nothing detected): fall back to the full frame.
                bbox = [0, 0, width, height]

            x0, y0, x1, y1 = bbox
            H = y1 - y0
            W = x1 - x0
            x0 = int(max(0, x0 - ratio * W))
            x1 = int(min(width, x1 + ratio * W))
            y0 = int(max(0, y0 - ratio * H))
            y1 = int(min(height, y1 + ratio * H))

            Image_input = image.crop((x0, y0, x1, y1))
            input_tensor = transform2(Image_input).unsqueeze(0)

            # Quantize the cropped trimap to the conventional 0/128/255 levels.
            trimap = trimap[y0:y1, x0:x1]
            trimap = np.where(trimap < 1, 0, trimap)
            trimap = np.where(trimap > 1, 255, trimap)
            trimap = np.where(trimap == 1, 128, trimap)

            # Stage 2: alpha matting on the cropped image + trimap pair.
            trimap = Image.fromarray(np.uint8(trimap)).convert('L')
            input_tensor2 = transform2(trimap).unsqueeze(0)
            with torch.no_grad():
                output = model[1]({'image': input_tensor, 'trimap': input_tensor2})['phas']
            output = F.interpolate(output, [Image_input.size[1], Image_input.size[0]], mode="bilinear")[0].numpy()

            numpy_image = (output * 255.).astype(np.uint8)

            # Paste the cropped matte back onto a full-size black canvas.
            numpy_image = numpy_image.squeeze(0)
            pil_image = Image.fromarray(numpy_image, mode='L')
            alpha = Image.new(mode='RGB', size=image.size)
            alpha.paste(pil_image, (x0, y0))

        alpha = np.array(alpha).astype(np.uint8)
        last_result["cache_key"] = cache_key
        last_result["algorithm"] = algorithm
        last_result["alpha"] = alpha

    # Optional morphological cleanup. The slider delivers a float, but both
    # the np.ones shape and cv2's iterations require integers.
    image = np.array(image)
    ksize = int(morph_op_factor)
    kernel = np.ones((ksize, ksize), np.uint8)
    if morph_op == "Dilate":
        alpha = cv2.dilate(alpha, kernel, iterations=ksize)
    elif morph_op == "Erode":
        alpha = cv2.erode(alpha, kernel, iterations=ksize)
    alpha = (alpha / 255).astype("float32")

    image = (image / 255.0).astype("float32")
    fg = estimate_foreground_ml(image, alpha)

    if result_type == "Remove BG":
        result = fg
    elif result_type == "Replace BG":
        # Parse "#rrggbb" into normalized RGB channels.
        bg_r = int(bg_color[1:3], base=16)
        bg_g = int(bg_color[3:5], base=16)
        bg_b = int(bg_color[5:7], base=16)

        bg = np.zeros_like(fg)
        bg[:, :, 0] = bg_r / 255.
        bg[:, :, 1] = bg_g / 255.
        bg[:, :, 2] = bg_b / 255.

        result = alpha * image + (1 - alpha) * bg
        result = np.clip(result, 0, 1)
    else:
        result = alpha

    return result
|
|
|
|
def main():
    """Assemble the Gradio UI and start the web server."""
    with gr.Blocks() as demo:
        gr.Markdown("Salient Object Matting")

        # Input image on the left, rendered result on the right.
        with gr.Row(variant="panel"):
            input_image = gr.Image(type='pil')
            output_image = gr.Image()

        # Output mode, replacement color, and algorithm selection.
        with gr.Row(variant="panel"):
            mode_radio = gr.Radio(
                label="Mode",
                show_label=True,
                choices=["Remove BG", "Replace BG", "Generate Mask"],
                value="Remove BG",
            )
            color_picker = gr.ColorPicker(
                label="BG Color",
                show_label=True,
                value="#000000",
            )
            algo_dropdown = gr.Dropdown(
                label="Algorithm",
                show_label=True,
                choices=model_names,
                value="matting",
            )

        # Morphological post-processing controls.
        with gr.Row(variant="panel"):
            morph_radio = gr.Radio(
                label="Post-process",
                show_label=True,
                choices=["None", "Erode", "Dilate"],
                value="None",
            )
            factor_slider = gr.Slider(
                label="Factor",
                show_label=True,
                minimum=3,
                maximum=20,
                value=3,
                step=2,
            )

        run = gr.Button("Run")
        run.click(
            image_matting,
            inputs=[
                input_image,
                mode_radio,
                color_picker,
                algo_dropdown,
                morph_radio,
                factor_slider,
            ],
            outputs=output_image,
        )

    demo.launch()
|
|
|
|
# Launch the app only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|