|
|
import base64
|
|
|
import json
|
|
|
import os
|
|
|
import re
|
|
|
import time
|
|
|
import uuid
|
|
|
from io import BytesIO
|
|
|
from pathlib import Path
|
|
|
import cv2
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
import pandas as pd
|
|
|
import streamlit as st
|
|
|
from PIL import Image
|
|
|
from streamlit_drawable_canvas import st_canvas
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
import io
|
|
|
import multiprocessing
|
|
|
from typing import Union
|
|
|
|
|
|
import torch
|
|
|
|
|
|
# Best-effort: switch off the TorchScript fusers before any scripted model is
# loaded. These are private ``torch._C`` hooks, so a failure here (presumably a
# hook missing from the installed torch build -- TODO confirm) must not stop
# the module from importing.
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    # Deliberately ignored, but narrowed from a bare ``except`` so that
    # KeyboardInterrupt / SystemExit still propagate.
    pass
|
|
|
|
|
|
from src.helper import (
|
|
|
download_model,
|
|
|
load_img,
|
|
|
norm_img,
|
|
|
numpy_to_bytes,
|
|
|
pad_img_to_modulo,
|
|
|
resize_max_size,
|
|
|
)
|
|
|
|
|
|
# Thread count handed to the native math libraries: the CPU count reported by
# multiprocessing.cpu_count(), stored as a string because environment values
# must be strings.
NUM_THREADS = str(multiprocessing.cpu_count())

# Publish the same thread count under every backend-specific variable name.
os.environ.update(
    dict.fromkeys(
        (
            "OMP_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "MKL_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS",
            "NUMEXPR_NUM_THREADS",
        ),
        NUM_THREADS,
    )
)

# When CACHE_DIR is set (and non-empty), reuse it as TORCH_HOME.
if os.environ.get("CACHE_DIR"):
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from scipy import ndimage as ndi
|
|
|
|
|
|
# Colour that visualize() paints over the pixels of the current seam.
# NOTE(review): channel order presumably follows OpenCV's BGR -- confirm.
SEAM_COLOR = np.array((255, 200, 200))

# s_image() shrinks inputs wider than DOWNSIZE_WIDTH down to that width
# whenever SHOULD_DOWNSIZE is true.
SHOULD_DOWNSIZE = True
DOWNSIZE_WIDTH = 500

# get_minimum_seam() overwrites the energy of pixels whose mask value exceeds
# MASK_THRESHOLD: +ENERGY_MASK_CONST for the protective mask and
# -ENERGY_MASK_CONST * 100 for the removal mask.
ENERGY_MASK_CONST = 1e5
MASK_THRESHOLD = 10

# Chooses forward_energy (True) or backward_energy (False) in get_minimum_seam().
USE_FORWARD_ENERGY = True
|
|
|
|
|
|
# The TorchScript checkpoint is loaded once, at import time, and pinned to the
# CPU; run() reads the module-level ``model`` and ``device`` names.
model_path = "./assets/big-lama.pt"
device = torch.device("cpu")

model = torch.jit.load(model_path, map_location="cpu").to(device)
model.eval()  # inference only
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def visualize(im, boolmask=None, rotate=False):
    """Display ``im`` in an OpenCV window and return the displayed frame.

    Args:
        im: Image array; it is converted to ``uint8`` before display.
        boolmask: Optional mask; pixels where it is false are painted with
            ``SEAM_COLOR``.
        rotate: When true, the frame is passed through
            ``rotate_image(frame, False)`` before being shown.

    Returns:
        The ``uint8`` frame that was handed to ``cv2.imshow``.
    """
    frame = im.astype(np.uint8)

    if boolmask is not None:
        # Highlight the seam, i.e. the positions the mask marks as false.
        frame[np.logical_not(boolmask)] = SEAM_COLOR

    if rotate:
        frame = rotate_image(frame, False)

    cv2.imshow("visualization", frame)
    cv2.waitKey(1)
    return frame
|
|
|
|
|
|
def resize(image, width):
    """Resize ``image`` to ``width`` columns, keeping its aspect ratio.

    The image is converted to ``float32`` before being passed to
    ``cv2.resize``; the new height is ``int(h * width / w)``.
    """
    src_h, src_w = image.shape[:2]
    target_h = int(src_h * width / float(src_w))
    as_float = image.astype('float32')
    return cv2.resize(as_float, (width, target_h))
|
|
|
|
|
|
def rotate_image(image, clockwise):
    """Rotate ``image`` by a quarter turn using ``np.rot90``.

    ``np.rot90`` is called with ``k=1`` when ``clockwise`` is truthy and with
    ``k=3`` otherwise, so calling this once with ``True`` and once with
    ``False`` restores the original orientation.

    NOTE(review): numpy documents positive ``k`` as a counter-clockwise
    rotation, so the flag name may be inverted relative to the visual
    direction; callers in this file only rely on the two values being
    inverses of each other.
    """
    quarter_turns = 3
    if clockwise:
        quarter_turns = 1
    return np.rot90(image, quarter_turns)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def backward_energy(im):
    """
    Gradient-magnitude energy map.

    Each axis is filtered with the kernel ``[1, 0, -1]`` (wrap-around borders);
    the squared responses are summed over axis 2 and the square root of the
    combined x/y total is returned as a 2-D array.
    """
    kernel = np.array([1, 0, -1])
    horizontal = ndi.convolve1d(im, kernel, axis=1, mode='wrap')
    vertical = ndi.convolve1d(im, kernel, axis=0, mode='wrap')

    horizontal_sq = np.sum(horizontal**2, axis=2)
    vertical_sq = np.sum(vertical**2, axis=2)
    return np.sqrt(horizontal_sq + vertical_sq)
|
|
|
|
|
|
def forward_energy(im):
    """
    Forward energy as described in "Improved Seam Carving for Video
    Retargeting" by Rubinstein, Shamir, Avidan.

    Vectorized code adapted from
    https://github.com/axu2/improved-seam-carving.

    The input is cast to ``uint8`` and converted to grayscale with
    ``cv2.COLOR_BGR2GRAY``. Row 0 of the returned map stays zero; every later
    row holds, per column, the step cost of whichever of the three moves
    (up, up-left, up-right) gives the smallest cumulative cost.
    """
    rows, cols = im.shape[:2]
    gray = cv2.cvtColor(im.astype(np.uint8), cv2.COLOR_BGR2GRAY).astype(np.float64)

    energy = np.zeros((rows, cols))
    cumulative = np.zeros((rows, cols))

    # Neighbour images; np.roll wraps around at the borders.
    above = np.roll(gray, 1, axis=0)
    left = np.roll(gray, 1, axis=1)
    right = np.roll(gray, -1, axis=1)

    cost_up = np.abs(right - left)
    cost_left = np.abs(above - left) + cost_up
    cost_right = np.abs(above - right) + cost_up

    for y in range(1, rows):
        prev = cumulative[y - 1]
        step_costs = np.array([cost_up[y], cost_left[y], cost_right[y]])
        totals = np.array([prev, np.roll(prev, 1), np.roll(prev, -1)])
        totals += step_costs

        cheapest = np.argmin(totals, axis=0)
        cumulative[y] = np.choose(cheapest, totals)
        energy[y] = np.choose(cheapest, step_costs)

    return energy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_seam(im, seam_idx):
    """
    Add a vertical seam to a 3-channel color image at the indices provided
    by averaging the pixel values to the left and right of the seam.
    Code adapted from https://github.com/vivianhylee/seam-carving.

    Args:
        im: Array of shape ``(h, w, 3)`` (only the first three channels are
            read).
        seam_idx: Sequence of length ``h`` giving, per row, the column at
            which the new pixel is inserted.

    Returns:
        A new float array of shape ``(h, w + 1, 3)``.

    For ``col > 0`` the new pixel is the mean of columns ``col - 1`` and
    ``col`` and is placed at index ``col``. For ``col == 0`` the new pixel is
    the mean of columns 0 and 1 and is placed at index 1. The previous
    implementation computed that mean but then overwrote it with a copy of
    column 0.
    """
    h, w = im.shape[:2]
    output = np.zeros((h, w + 1, 3))
    for row in range(h):
        col = seam_idx[row]
        pixels = im[row, :, :3]
        if col == 0:
            output[row, 0] = pixels[0]
            output[row, 1] = np.mean(pixels[0:2], axis=0)
            # Shift the rest by one so the averaged pixel is kept.
            output[row, 2:] = pixels[1:]
        else:
            output[row, :col] = pixels[:col]
            output[row, col] = np.mean(pixels[col - 1:col + 1], axis=0)
            output[row, col + 1:] = pixels[col:]

    return output
|
|
|
|
|
|
def add_seam_grayscale(im, seam_idx):
    """
    Add a vertical seam to a grayscale image at the indices provided
    by averaging the pixel values to the left and right of the seam.

    Args:
        im: 2-D array of shape ``(h, w)``.
        seam_idx: Sequence of length ``h`` giving, per row, the column at
            which the new pixel is inserted.

    Returns:
        A new float array of shape ``(h, w + 1)``.

    For ``col > 0`` the new value is the mean of columns ``col - 1`` and
    ``col`` and is placed at index ``col``. For ``col == 0`` the new value is
    the mean of columns 0 and 1 and is placed at index 1. The previous
    implementation computed that mean but then overwrote it with a copy of
    column 0.
    """
    h, w = im.shape[:2]
    output = np.zeros((h, w + 1))
    for row in range(h):
        col = seam_idx[row]
        values = im[row]
        if col == 0:
            output[row, 0] = values[0]
            output[row, 1] = np.mean(values[0:2])
            # Shift the rest by one so the averaged value is kept.
            output[row, 2:] = values[1:]
        else:
            output[row, :col] = values[:col]
            output[row, col] = np.mean(values[col - 1:col + 1])
            output[row, col + 1:] = values[col:]

    return output
|
|
|
|
|
|
def remove_seam(im, boolmask):
    """Drop one pixel per row from a 3-channel image.

    ``boolmask`` has shape ``(h, w)`` and is false at exactly the pixel to
    remove in each row; the result has shape ``(h, w - 1, 3)``.
    """
    rows, cols = im.shape[:2]
    # Repeat the 2-D mask across the three channels.
    keep = np.stack((boolmask, boolmask, boolmask), axis=2)
    survivors = im[keep]
    return survivors.reshape((rows, cols - 1, 3))
|
|
|
|
|
|
def remove_seam_grayscale(im, boolmask):
    """Drop one pixel per row from a 2-D image.

    ``boolmask`` has shape ``(h, w)`` and is false at exactly the pixel to
    remove in each row; the result has shape ``(h, w - 1)``.
    """
    rows, cols = im.shape[:2]
    survivors = im[boolmask]
    return survivors.reshape((rows, cols - 1))
|
|
|
|
|
|
def get_minimum_seam(im, mask=None, remove_mask=None):
    """
    DP algorithm for finding the seam of minimum energy. Code adapted from
    https://karthikkaranth.me/blog/implementing-seam-carving-with-python/

    Args:
        im: Image whose energy map is computed with ``forward_energy`` or
            ``backward_energy`` depending on ``USE_FORWARD_ENERGY``.
        mask: Optional protective mask; energy is set to
            ``ENERGY_MASK_CONST`` where it exceeds ``MASK_THRESHOLD``.
        remove_mask: Optional removal mask; energy is set to
            ``-ENERGY_MASK_CONST * 100`` where it exceeds ``MASK_THRESHOLD``.
            It is applied after ``mask`` and therefore wins where the two
            overlap.

    Returns:
        ``(seam_idx, boolmask)``: the per-row seam columns as an array and
        the boolean keep-mask produced by ``compute_shortest_path``.
    """
    rows, cols = im.shape[:2]

    if USE_FORWARD_ENERGY:
        energy_map = forward_energy(im)
    else:
        energy_map = backward_energy(im)

    if mask is not None:
        energy_map[mask > MASK_THRESHOLD] = ENERGY_MASK_CONST

    if remove_mask is not None:
        energy_map[remove_mask > MASK_THRESHOLD] = -ENERGY_MASK_CONST * 100

    seam, keep = compute_shortest_path(energy_map, im, rows, cols)
    return np.array(seam), keep
|
|
|
|
|
|
def compute_shortest_path(M, im, h, w):
    """Find the minimum-cost vertical seam through the energy map ``M``.

    Args:
        M: ``(h, w)`` energy map. It is updated in place: after the call row
            ``i`` holds the cumulative cost of the cheapest path ending there.
        im: Unused; kept so existing callers keep working.
        h: Number of rows of ``M``.
        w: Number of columns of ``M``.

    Returns:
        ``(seam_idx, boolmask)``: a list with the seam column for each row
        (top to bottom) and an ``(h, w)`` boolean array that is false on the
        seam and true elsewhere.

    The cumulative-cost pass handles a whole row at a time instead of one
    pixel at a time. Ties are resolved towards the leftmost candidate, as
    ``np.argmin`` did on the per-pixel window.
    """
    backtrack = np.zeros_like(M, dtype=np.int_)
    cols = np.arange(w)

    for i in range(1, h):
        # Edge-pad the previous row so every column has three candidates
        # (up-left, up, up-right); the padded duplicates never beat the real
        # edge value, and clipping maps them back to a valid column.
        padded = np.pad(M[i - 1], 1, mode="edge")
        window = np.stack((padded[:-2], padded[1:-1], padded[2:]))
        parent = np.clip(cols + np.argmin(window, axis=0) - 1, 0, w - 1)
        backtrack[i] = parent
        M[i] += M[i - 1, parent]

    # Walk back up from the cheapest cell in the last row.
    seam_idx = []
    boolmask = np.ones((h, w), dtype=np.bool_)
    j = np.argmin(M[-1])
    for i in range(h - 1, -1, -1):
        boolmask[i, j] = False
        seam_idx.append(j)
        j = backtrack[i, j]

    seam_idx.reverse()
    return seam_idx, boolmask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def seams_removal(im, num_remove, mask=None, vis=False, rot=False):
    """Remove ``num_remove`` minimum-energy vertical seams from ``im``.

    Args:
        im: 3-channel image to shrink.
        num_remove: Number of seams to remove.
        mask: Optional protective mask; it is shrunk in step with the image.
        vis: When true, each seam is shown via ``visualize`` before removal.
        rot: Forwarded to ``visualize`` as its ``rotate`` flag.

    Returns:
        ``(im, mask)`` after all removals.
    """
    for _ in range(num_remove):
        _, keep = get_minimum_seam(im, mask)

        if vis:
            visualize(im, keep, rotate=rot)

        im = remove_seam(im, keep)
        if mask is not None:
            mask = remove_seam_grayscale(mask, keep)

    return im, mask
|
|
|
|
|
|
|
|
|
def seams_insertion(im, num_add, mask=None, vis=False, rot=False):
    """Insert ``num_add`` vertical seams into ``im``.

    The seams are first found by removing them, one after another, from
    copies of the image and mask. They are then inserted into the real image
    in the order they were found. After each insertion, every not-yet-inserted
    seam has its indices at or beyond the inserted seam shifted by 2.

    Args:
        im: 3-channel image to widen.
        num_add: Number of seams to insert.
        mask: Optional protective mask; it is widened in step with the image.
        vis: When true, intermediate images are shown via ``visualize``.
        rot: Forwarded to ``visualize`` as its ``rotate`` flag.

    Returns:
        ``(im, mask)`` after all insertions.
    """
    scratch_im = im.copy()
    scratch_mask = None if mask is None else mask.copy()

    # Pass 1: find the seams on throw-away copies.
    planned = []
    for _ in range(num_add):
        seam_idx, keep = get_minimum_seam(scratch_im, scratch_mask)
        if vis:
            visualize(scratch_im, keep, rotate=rot)

        planned.append(seam_idx)
        scratch_im = remove_seam(scratch_im, keep)
        if scratch_mask is not None:
            scratch_mask = remove_seam_grayscale(scratch_mask, keep)

    # Pass 2: insert them, earliest-found first.
    for position, seam in enumerate(planned):
        im = add_seam(im, seam)
        if vis:
            visualize(im, rotate=rot)
        if mask is not None:
            mask = add_seam_grayscale(mask, seam)

        # Shift the seams that are still waiting to be inserted.
        for pending in planned[position + 1:]:
            pending[pending >= seam] += 2

    return im, mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def seam_carve(im, dy, dx, mask=None, vis=False):
    """Resize ``im`` by ``dx`` columns and ``dy`` rows using seam carving.

    Negative deltas remove seams, positive deltas insert them. The width is
    handled first. The height is handled by rotating the image (and mask)
    with ``rotate_image``, carving vertical seams, and rotating back.

    Args:
        im: 3-channel image; converted to ``float64``.
        dy: Change in height.
        dx: Change in width.
        mask: Optional protective mask; converted to ``float64``.
        vis: Forwarded to the seam routines to enable visualization.

    Returns:
        The resized image. The updated mask is not returned.

    Raises:
        AssertionError: If the target size is not positive, or ``dy``/``dx``
            exceeds the current height/width.
    """
    im = im.astype(np.float64)
    h, w = im.shape[:2]
    assert h + dy > 0 and w + dx > 0 and dy <= h and dx <= w

    if mask is not None:
        mask = mask.astype(np.float64)

    output = im

    # Width.
    if dx < 0:
        output, mask = seams_removal(output, -dx, mask, vis)
    elif dx > 0:
        output, mask = seams_insertion(output, dx, mask, vis)

    # Height: work on the rotated image, then undo the rotation.
    if dy != 0:
        output = rotate_image(output, True)
        if mask is not None:
            mask = rotate_image(mask, True)

        if dy < 0:
            output, mask = seams_removal(output, -dy, mask, vis, rot=True)
        else:
            output, mask = seams_insertion(output, dy, mask, vis, rot=True)

        output = rotate_image(output, False)

    return output
|
|
|
|
|
|
|
|
|
def object_removal(im, rmask, mask=None, vis=False, horizontal_removal=False):
    """Remove the region marked by ``rmask`` and restore the original size.

    Seams are removed until no ``rmask`` value exceeds ``MASK_THRESHOLD``.
    The same number of seams is then re-inserted with ``seams_insertion``.

    Args:
        im: 3-channel image; converted to ``float64``.
        rmask: Removal mask; converted to ``float64``.
        mask: Optional protective mask; converted to ``float64``.
        vis: When true, each removed seam is shown via ``visualize``.
        horizontal_removal: When true, image and masks are rotated with
            ``rotate_image`` first and the result is rotated back at the end.

    Returns:
        The processed image.
    """
    output = im.astype(np.float64)
    rmask = rmask.astype(np.float64)
    if mask is not None:
        mask = mask.astype(np.float64)

    h, w = im.shape[:2]

    if horizontal_removal:
        output = rotate_image(output, True)
        rmask = rotate_image(rmask, True)
        if mask is not None:
            mask = rotate_image(mask, True)

    # Carve until nothing in the removal mask is above the threshold.
    while (rmask > MASK_THRESHOLD).any():
        _, keep = get_minimum_seam(output, mask, rmask)
        if vis:
            visualize(output, keep, rotate=horizontal_removal)

        output = remove_seam(output, keep)
        rmask = remove_seam_grayscale(rmask, keep)
        if mask is not None:
            mask = remove_seam_grayscale(mask, keep)

    # Put back as many seams as were taken out.
    original_extent = h if horizontal_removal else w
    num_add = original_extent - output.shape[1]
    output, mask = seams_insertion(output, num_add, mask, vis, rot=horizontal_removal)

    if horizontal_removal:
        output = rotate_image(output, False)

    return output
|
|
|
|
|
|
|
|
|
|
|
|
def s_image(im, mask, vs, hs, mode="resize"):
    """Entry point for seam-carving resize or object removal.

    Args:
        im: Image converted with ``cv2.COLOR_RGBA2RGB`` before processing.
        mask: Array whose channel 3 is inverted (``255 - mask[:, :, 3]``)
            to obtain the working mask.
        vs: Passed to ``seam_carve`` as ``dx`` (width change) in resize mode.
        hs: Passed to ``seam_carve`` as ``dy`` (height change) in resize mode.
        mode: ``"resize"`` calls ``seam_carve`` with the mask as protective
            mask. ``"remove"`` calls ``object_removal`` with the mask as
            removal mask and ``horizontal_removal=True``.

    Returns:
        The processed image.

    Raises:
        AssertionError: In resize mode, if ``hs`` or ``vs`` is ``None``.
        ValueError: If ``mode`` is neither ``"resize"`` nor ``"remove"``.
            Previously this surfaced as an ``UnboundLocalError``.
    """
    im = cv2.cvtColor(im, cv2.COLOR_RGBA2RGB)
    mask = 255 - mask[:, :, 3]

    w = im.shape[1]
    if SHOULD_DOWNSIZE and w > DOWNSIZE_WIDTH:
        # Shrink image and mask together so they stay aligned.
        im = resize(im, width=DOWNSIZE_WIDTH)
        mask = resize(mask, width=DOWNSIZE_WIDTH)

    if mode == "resize":
        dy = hs
        dx = vs
        assert dy is not None and dx is not None
        output = seam_carve(im, dy, dx, mask, False)
    elif mode == "remove":
        output = object_removal(im, mask, None, False, True)
    else:
        raise ValueError(f"unsupported mode: {mode!r}")

    return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run(image, mask):
    """
    Run the loaded inpainting model on one image/mask pair.

    image: [C, H, W]
    mask: [1, H, W]

    Both inputs are padded with ``pad_img_to_modulo(..., mod=8)``; the mask
    is binarised (``> 0``) before inference. The prediction is cropped back
    to the original height/width, scaled by 255, clipped to ``[0, 255]`` and
    cast to ``uint8``.

    return: image with its channel order swapped by ``cv2.COLOR_BGR2RGB``
    (the original docstring labels this a BGR image -- TODO confirm).
    """
    origin_height, origin_width = image.shape[1:]

    padded_image = pad_img_to_modulo(image, mod=8)
    padded_mask = pad_img_to_modulo(mask, mod=8)
    binary_mask = (padded_mask > 0) * 1

    image_batch = torch.from_numpy(padded_image).unsqueeze(0).to(device)
    mask_batch = torch.from_numpy(binary_mask).unsqueeze(0).to(device)

    start = time.time()
    with torch.no_grad():
        prediction = model(image_batch, mask_batch)
    print(f"process time: {(time.time() - start)*1000}ms")

    # Batch element 0, channels moved last, then crop away the padding.
    result = prediction[0].permute(1, 2, 0).detach().cpu().numpy()
    result = result[0:origin_height, 0:origin_width, :]
    result = np.clip(result * 255, 0, 255).astype("uint8")
    return cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
|
|
|
|
|
|
|
|
|
def get_args_parser():
    """Parse the command line and return the resulting namespace.

    Options: ``--port`` (int, default 8080), ``--device`` (str, default
    ``"cuda"``) and the boolean flag ``--debug``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--port", type=int, default=8080)
    cli.add_argument("--device", type=str, default="cuda")
    cli.add_argument("--debug", action="store_true")
    parsed = cli.parse_args()
    return parsed
|
|
|
|
|
|
|
|
|
def process_inpaint(image, mask):
    """Inpaint ``image`` where ``mask`` marks pixels, using ``run``.

    Args:
        image: Image converted with ``cv2.COLOR_RGBA2RGB`` before processing.
        mask: Array whose channel 3 is inverted (``255 - mask[:, :, 3]``)
            to obtain the inpainting mask.

    Both arrays go through ``resize_max_size`` with cubic interpolation and
    a size limit equal to the largest dimension of the RGB image, followed
    by ``norm_img``.

    Returns:
        The output of ``run`` with its channel order swapped by
        ``cv2.COLOR_BGR2RGB``.
    """
    rgb = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    original_shape = rgb.shape
    interpolation = cv2.INTER_CUBIC

    # The limit is the image's own largest dimension.
    size_limit = max(rgb.shape)

    print(f"Origin image shape: {original_shape}")
    rgb = resize_max_size(rgb, size_limit=size_limit, interpolation=interpolation)
    print(f"Resized image shape: {rgb.shape}")
    model_image = norm_img(rgb)

    inverted_alpha = 255 - mask[:, :, 3]
    inverted_alpha = resize_max_size(
        inverted_alpha, size_limit=size_limit, interpolation=interpolation
    )
    model_mask = norm_img(inverted_alpha)

    res_np_img = run(model_image, model_mask)
    return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)