# object_remover / src/core.py (LogicGoInfotechSpaces)
# Commit d7d0150: "Ensure mask matches resized image size"
import base64
import json
import os
import re
import time
import uuid
from io import BytesIO
from pathlib import Path
import cv2
# For inpainting
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas
import argparse
import io
import multiprocessing
from typing import Union
import torch
# Turn off the TorchScript fusers before the LaMa model below is loaded.
# These are private torch._C toggles whose availability differs between torch
# releases, so each one is applied independently and on a best-effort basis:
# a missing or rejected toggle must not prevent the remaining ones from running.
for _jit_toggle in (
    "_jit_override_can_fuse_on_cpu",
    "_jit_override_can_fuse_on_gpu",
    "_jit_set_texpr_fuser_enabled",
    "_jit_set_nvfuser_enabled",
):
    try:
        getattr(torch._C, _jit_toggle)(False)
    except Exception:  # deliberate best effort; unlike bare except, lets SystemExit/KeyboardInterrupt through
        pass
del _jit_toggle
from src.helper import (
download_model,
load_img,
norm_img,
numpy_to_bytes,
pad_img_to_modulo,
resize_max_size,
)
# Cap every common native threading runtime at the number of logical CPUs.
# NOTE(review): numpy/torch are already imported above, so any runtime that
# reads these variables only at import time may ignore them -- confirm.
NUM_THREADS = str(multiprocessing.cpu_count())
os.environ.update(
    dict.fromkeys(
        (
            "OMP_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "MKL_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS",
            "NUMEXPR_NUM_THREADS",
        ),
        NUM_THREADS,
    )
)

# A non-empty CACHE_DIR doubles as torch's cache/home directory.
if "CACHE_DIR" in os.environ and os.environ["CACHE_DIR"]:
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
#BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
# For Seam-carving
from scipy import ndimage as ndi
# --- Seam-carving tuning knobs -------------------------------------------
# Colour painted over the current seam by visualize() (BGR order).
SEAM_COLOR = np.array((255, 200, 200))
# Shrink inputs wider than DOWNSIZE_WIDTH before carving, trading quality for speed.
SHOULD_DOWNSIZE = True
DOWNSIZE_WIDTH = 500
# Energy magnitude written into masked pixels by get_minimum_seam(): protected
# pixels get +ENERGY_MASK_CONST, pixels to remove get -ENERGY_MASK_CONST * 100.
ENERGY_MASK_CONST = 1e5
# A mask pixel counts as "set" only when its value exceeds this threshold.
MASK_THRESHOLD = 10
# Choose forward_energy() (True) or backward_energy() (False) for the energy map.
USE_FORWARD_ENERGY = True
# Inference device: $DEVICE when set, otherwise CUDA if available, else CPU.
if torch.cuda.is_available():
    device_str = os.environ.get("DEVICE", "cuda")
else:
    device_str = os.environ.get("DEVICE", "cpu")
device = torch.device(device_str)

# LaMa TorchScript checkpoint, loaded once at import time. The path is
# relative to the current working directory, not to this file.
model_path = "./assets/big-lama.pt"
model = torch.jit.load(model_path, map_location=device).to(device)
model.eval()
########################################
# UTILITY CODE
########################################
def visualize(im, boolmask=None, rotate=False):
    """Display an image (optionally with a seam overlay) and return the frame shown.

    Args:
        im: image array; it is converted to uint8 (a copy), so *im* is not modified.
        boolmask: optional array; every position where it is False is painted
            with SEAM_COLOR.
        rotate: when True the frame is rotated back with rotate_image(..., False)
            before display, for callers that carve on a rotated image.

    Returns:
        The uint8 frame passed to cv2.imshow.
    """
    frame = im.astype(np.uint8)
    if boolmask is not None:
        seam_positions = np.where(np.logical_not(boolmask))
        frame[seam_positions] = SEAM_COLOR
    if rotate:
        frame = rotate_image(frame, False)
    cv2.imshow("visualization", frame)
    cv2.waitKey(1)
    return frame
def resize(image, width):
    """Resize *image* to *width* pixels wide, keeping the aspect ratio.

    The new height is truncated to an int, as before, but never drops below one
    pixel, because cv2.resize rejects a zero-sized target. The image is cast to
    float32 before resizing, so the result is float32.

    Args:
        image: array whose first two dimensions are (height, width).
        width: target width in pixels.

    Returns:
        The resized float32 array.
    """
    h, w = image.shape[:2]
    height = max(1, int(h * width / float(w)))
    return cv2.resize(image.astype('float32'), (width, height))
def rotate_image(image, clockwise):
    """Rotate *image* by a quarter turn with np.rot90.

    ``clockwise=True`` uses one quarter turn (k=1) and ``False`` uses three (k=3),
    so the two settings undo each other. Note that numpy's k=1 rotates from the
    first axis toward the second, which for (row, col) images is counter-clockwise.
    Callers in this module only rely on True/False being inverses.
    """
    quarter_turns = 3
    if clockwise:
        quarter_turns = 1
    return np.rot90(image, quarter_turns)
########################################
# ENERGY FUNCTIONS
########################################
def backward_energy(im):
    """Return the gradient-magnitude energy map of a multi-channel image.

    Central differences ([1, 0, -1]) are taken along columns and rows with
    periodic ('wrap') borders. Their squares are summed over the channel axis,
    and the square root of the total is returned as a 2-D (h, w) map.
    """
    kernel = np.array([1, 0, -1])
    dx = ndi.convolve1d(im, kernel, axis=1, mode='wrap')
    dy = ndi.convolve1d(im, kernel, axis=0, mode='wrap')
    squared_norm = np.sum(dx ** 2, axis=2) + np.sum(dy ** 2, axis=2)
    return np.sqrt(squared_norm)
def forward_energy(im):
    """
    Forward energy algorithm as described in "Improved Seam Carving for Video Retargeting"
    by Rubinstein, Shamir, Avidan.
    Vectorized code adapted from
    https://github.com/axu2/improved-seam-carving.

    The image is converted to grayscale (cast to uint8, then BGR2GRAY). For every
    row after the first, each pixel takes the cheapest of three moves (up, up-left,
    up-right). The returned (h, w) map stores the cost of the chosen move; row 0
    stays zero.
    """
    h, w = im.shape[:2]
    gray = cv2.cvtColor(im.astype(np.uint8), cv2.COLOR_BGR2GRAY).astype(np.float64)

    # Neighbour images, with wrap-around at the borders.
    above = np.roll(gray, 1, axis=0)
    left = np.roll(gray, 1, axis=1)
    right = np.roll(gray, -1, axis=1)

    # Cost of the new edges created by each of the three possible moves.
    cost_up = np.abs(right - left)
    cost_left = np.abs(above - left) + cost_up
    cost_right = np.abs(above - right) + cost_up

    energy = np.zeros((h, w))
    cumulative = np.zeros((h, w))
    for row in range(1, h):
        prev = cumulative[row - 1]
        totals = np.array([prev, np.roll(prev, 1), np.roll(prev, -1)])
        step_costs = np.array([cost_up[row], cost_left[row], cost_right[row]])
        totals += step_costs
        best = np.argmin(totals, axis=0)
        cumulative[row] = np.choose(best, totals)
        energy[row] = np.choose(best, step_costs)
    return energy
########################################
# SEAM HELPER FUNCTIONS
########################################
def add_seam(im, seam_idx):
    """
    Add a vertical seam to a 3-channel color image at the indices provided
    by averaging the pixels values to the left and right of the seam.
    Code adapted from https://github.com/vivianhylee/seam-carving.

    Args:
        im: (h, w, C) image with C >= 3; only the first three channels are used.
        seam_idx: sequence of h column indices, one per row.

    Returns:
        A float64 array of shape (h, w + 1, 3). In each row, a new pixel equal to
        the mean of its two neighbours is inserted next to column seam_idx[row].
    """
    src = im[:, :, :3]
    h, w = src.shape[:2]
    output = np.zeros((h, w + 1, 3))
    for row in range(h):
        col = seam_idx[row]
        if col == 0:
            # No left neighbour: average pixels 0 and 1 and insert the result
            # between them. The tail must start at col + 2, otherwise the
            # averaged pixel would be overwritten by a copy of pixel 0.
            output[row, col] = src[row, col]
            output[row, col + 1] = np.mean(src[row, col: col + 2], axis=0)
            output[row, col + 2:] = src[row, col + 1:]
        else:
            # Insert the average of columns col-1 and col just before col.
            output[row, :col] = src[row, :col]
            output[row, col] = np.mean(src[row, col - 1: col + 1], axis=0)
            output[row, col + 1:] = src[row, col:]
    return output
def add_seam_grayscale(im, seam_idx):
    """
    Add a vertical seam to a grayscale image at the indices provided
    by averaging the pixels values to the left and right of the seam.

    Args:
        im: (h, w) single-channel image (e.g. a mask).
        seam_idx: sequence of h column indices, one per row.

    Returns:
        A float64 array of shape (h, w + 1).
    """
    h, w = im.shape[:2]
    output = np.zeros((h, w + 1))
    for row in range(h):
        col = seam_idx[row]
        if col == 0:
            # No left neighbour: insert the mean of pixels 0 and 1 between them.
            # The tail starts at col + 2 so the inserted value is not overwritten.
            output[row, col] = im[row, col]
            output[row, col + 1] = np.mean(im[row, col: col + 2])
            output[row, col + 2:] = im[row, col + 1:]
        else:
            output[row, :col] = im[row, :col]
            output[row, col] = np.mean(im[row, col - 1: col + 1])
            output[row, col + 1:] = im[row, col:]
    return output
def remove_seam(im, boolmask):
    """Drop the seam pixels (where *boolmask* is False) from a 3-channel image.

    *boolmask* is (h, w) with exactly one False per row. The result has shape
    (h, w - 1, 3).
    """
    rows, cols = im.shape[:2]
    keep = np.repeat(boolmask[:, :, np.newaxis], 3, axis=2)
    kept_pixels = im[keep]
    return kept_pixels.reshape((rows, cols - 1, 3))
def remove_seam_grayscale(im, boolmask):
    """Drop the seam pixels (where *boolmask* is False) from a 2-D image.

    Exactly one False per row is expected; the result has shape (h, w - 1).
    """
    rows, cols = im.shape[:2]
    kept_pixels = im[boolmask]
    return kept_pixels.reshape((rows, cols - 1))
def get_minimum_seam(im, mask=None, remove_mask=None):
    """
    DP algorithm for finding the seam of minimum energy. Code adapted from
    https://karthikkaranth.me/blog/implementing-seam-carving-with-python/

    Args:
        im: image whose energy map is computed (forward or backward energy,
            selected by USE_FORWARD_ENERGY).
        mask: optional protective mask; pixels above MASK_THRESHOLD get energy
            ENERGY_MASK_CONST so seams avoid them.
        remove_mask: optional removal mask; pixels above MASK_THRESHOLD get
            energy -ENERGY_MASK_CONST * 100. This is applied after the protective
            mask, so removal wins where the two overlap.

    Returns:
        (seam_idx, boolmask): the seam's column per row as an ndarray, and an
        (h, w) bool mask that is False on the seam.
    """
    h, w = im.shape[:2]
    if USE_FORWARD_ENERGY:
        energy = forward_energy(im)
    else:
        energy = backward_energy(im)

    if mask is not None:
        energy[np.where(mask > MASK_THRESHOLD)] = ENERGY_MASK_CONST
    if remove_mask is not None:
        energy[np.where(remove_mask > MASK_THRESHOLD)] = -ENERGY_MASK_CONST * 100

    path, keep = compute_shortest_path(energy, im, h, w)
    return np.array(path), keep
def compute_shortest_path(M, im, h, w):
    """Find the minimum-cost vertical seam through the cost map *M*.

    Each pixel in row i extends the cheapest of its up-left, up and up-right
    neighbours in row i-1. On ties the leftmost candidate wins, and border
    columns only consider in-range neighbours. The DP is vectorized across each
    row instead of looping over every pixel in Python.

    Args:
        M: (h, w) array of per-pixel costs. It is overwritten IN PLACE with the
            cumulative costs.
        im: unused; kept for call compatibility.
        h, w: height and width of M.

    Returns:
        (seam_idx, boolmask): a list of h column indices (top to bottom) and an
        (h, w) bool array that is False on the seam.
    """
    backtrack = np.zeros_like(M, dtype=np.int_)
    cols = np.arange(w)
    for i in range(1, h):
        # Pad the previous row with +inf so every column sees three candidates:
        # [up-left, up, up-right]. np.argmin returns the first minimum, so ties
        # resolve leftmost, as in a per-pixel scan.
        prev = np.pad(M[i - 1].astype(np.float64), 1, mode="constant",
                      constant_values=np.inf)
        candidates = np.stack((prev[:-2], prev[1:-1], prev[2:]))
        # A +inf pad can only win when every real neighbour is also +inf. The clip
        # then falls back to the in-range column that a scan would have chosen.
        parent = np.clip(cols + np.argmin(candidates, axis=0) - 1, 0, w - 1)
        backtrack[i] = parent
        M[i] += M[i - 1, parent]

    # Walk back from the cheapest bottom pixel to recover the seam.
    seam_idx = []
    boolmask = np.ones((h, w), dtype=np.bool_)
    j = np.argmin(M[-1])
    for i in range(h - 1, -1, -1):
        boolmask[i, j] = False
        seam_idx.append(j)
        j = backtrack[i, j]
    seam_idx.reverse()
    return seam_idx, boolmask
########################################
# MAIN ALGORITHM
########################################
def seams_removal(im, num_remove, mask=None, vis=False, rot=False):
    """Remove *num_remove* minimum-energy vertical seams, one at a time.

    The optional protective *mask* is carved along with the image so it stays
    aligned. With *vis* set, each seam is shown via visualize() (rotated back
    when *rot* is True).

    Returns:
        (image, mask) after carving; mask stays None if none was given.
    """
    for _ in range(num_remove):
        _, keep = get_minimum_seam(im, mask)
        if vis:
            visualize(im, keep, rotate=rot)
        im = remove_seam(im, keep)
        mask = mask if mask is None else remove_seam_grayscale(mask, keep)
    return im, mask
def seams_insertion(im, num_add, mask=None, vis=False, rot=False):
    """Widen *im* by *num_add* columns by duplicating low-energy seams.

    Phase 1 finds *num_add* seams by repeatedly removing them from scratch
    copies. Phase 2 inserts them into the real image (and mask) in the order they
    were found. After each insertion, the column indices of the seams still to be
    inserted are shifted right by 2 wherever they are >= the inserted seam.

    Returns:
        (image, mask) after insertion; mask stays None if none was given.
    """
    scratch_im = im.copy()
    scratch_mask = None if mask is None else mask.copy()
    found = []
    for _ in range(num_add):
        seam_idx, keep = get_minimum_seam(scratch_im, scratch_mask)
        if vis:
            visualize(scratch_im, keep, rotate=rot)
        found.append(seam_idx)
        scratch_im = remove_seam(scratch_im, keep)
        if scratch_mask is not None:
            scratch_mask = remove_seam_grayscale(scratch_mask, keep)

    for position, seam in enumerate(found):
        im = add_seam(im, seam)
        if vis:
            visualize(im, rotate=rot)
        if mask is not None:
            mask = add_seam_grayscale(mask, seam)
        # Shift the seams that have not been inserted yet (mutated in place).
        for pending in found[position + 1:]:
            pending[pending >= seam] += 2
    return im, mask
########################################
# MAIN DRIVER FUNCTIONS
########################################
def seam_carve(im, dy, dx, mask=None, vis=False):
    """Content-aware resize of *im* by *dx* columns and *dy* rows.

    Negative deltas remove seams and positive ones insert seams. Width is handled
    first. Height is then handled by rotating the image (and mask) a quarter
    turn, carving vertical seams, and rotating the result back.

    Returns:
        The carved image as float64.
    """
    im = im.astype(np.float64)
    h, w = im.shape[:2]
    assert h + dy > 0 and w + dx > 0 and dy <= h and dx <= w
    if mask is not None:
        mask = mask.astype(np.float64)

    output = im
    # Width.
    if dx < 0:
        output, mask = seams_removal(output, -dx, mask, vis)
    elif dx > 0:
        output, mask = seams_insertion(output, dx, mask, vis)

    # Height: rows become columns after the rotation.
    if dy != 0:
        output = rotate_image(output, True)
        if mask is not None:
            mask = rotate_image(mask, True)
        if dy < 0:
            output, mask = seams_removal(output, -dy, mask, vis, rot=True)
        else:
            output, mask = seams_insertion(output, dy, mask, vis, rot=True)
        output = rotate_image(output, False)
    return output
def object_removal(im, rmask, mask=None, vis=False, horizontal_removal=False):
    """Erase the region marked in *rmask* by seam carving, then restore the size.

    Seams are removed until no *rmask* pixel exceeds MASK_THRESHOLD. The optional
    protective *mask* steers seams away from pixels. Seams are then inserted until
    the carved dimension matches the original again. With *horizontal_removal*,
    everything is rotated first, so horizontal seams are carved instead.

    Returns:
        The resulting float64 image.
    """
    output = im.astype(np.float64)
    rmask = rmask.astype(np.float64)
    if mask is not None:
        mask = mask.astype(np.float64)

    h, w = im.shape[:2]
    target_width = h if horizontal_removal else w

    if horizontal_removal:
        output = rotate_image(output, True)
        rmask = rotate_image(rmask, True)
        if mask is not None:
            mask = rotate_image(mask, True)

    while np.any(rmask > MASK_THRESHOLD):
        _, keep = get_minimum_seam(output, mask, rmask)
        if vis:
            visualize(output, keep, rotate=horizontal_removal)
        output = remove_seam(output, keep)
        rmask = remove_seam_grayscale(rmask, keep)
        if mask is not None:
            mask = remove_seam_grayscale(mask, keep)

    missing = target_width - output.shape[1]
    output, mask = seams_insertion(output, missing, mask, vis, rot=horizontal_removal)
    if horizontal_removal:
        output = rotate_image(output, False)
    return output
def s_image(im, mask, vs, hs, mode="resize"):
    """Seam-carve an RGBA image using its companion RGBA mask.

    Args:
        im: RGBA image; the alpha channel is dropped.
        mask: RGBA mask. Its alpha channel is inverted (255 - alpha), so
            transparent pixels become 255 and count as "set".
        vs: width change passed to seam_carve as dx (vertical seams; negative
            removes, positive inserts). Used only in "resize" mode.
        hs: height change passed to seam_carve as dy. Used only in "resize" mode.
        mode: "resize" treats the mask as protective and changes the size by
            (hs, vs). "remove" erases the masked region with horizontal seams and
            restores the original size.

    Returns:
        The carved float64 image (downsized first if wider than DOWNSIZE_WIDTH
        and SHOULD_DOWNSIZE is set).

    Raises:
        ValueError: if *mode* is neither "resize" nor "remove".
    """
    im = cv2.cvtColor(im, cv2.COLOR_RGBA2RGB)
    mask = 255 - mask[:, :, 3]
    w = im.shape[1]
    if SHOULD_DOWNSIZE and w > DOWNSIZE_WIDTH:
        im = resize(im, width=DOWNSIZE_WIDTH)
        mask = resize(mask, width=DOWNSIZE_WIDTH)
    # The carving code indexes the mask with image coordinates, so force an
    # exact size match (nearest-neighbour keeps the mask values unblended).
    if mask.shape[:2] != im.shape[:2]:
        mask = cv2.resize(mask, (im.shape[1], im.shape[0]),
                          interpolation=cv2.INTER_NEAREST)

    if mode == "resize":
        dy = hs
        dx = vs
        assert dy is not None and dx is not None
        return seam_carve(im, dy, dx, mask, False)
    if mode == "remove":
        return object_removal(im, mask, None, False, True)
    raise ValueError(f"Unknown mode {mode!r}; expected 'resize' or 'remove'")
##### Inpainting helper code
def run(image, mask):
    """
    Run the module-level LaMa model on one image/mask pair.

    image: [C, H, W]  (presumably float32 in [0, 1] from norm_img -- confirm)
    mask: [1, H, W]   (any value > 0 marks a pixel to inpaint)
    return: BGR IMAGE, uint8, [H, W, C], cropped back to the unpadded size
        (the final cvtColor swaps channels; assuming the model outputs RGB,
        the result is BGR)
    """
    origin_height, origin_width = image.shape[1:]
    # The network needs spatial sizes that are multiples of 8.
    image = pad_img_to_modulo(image, mod=8)
    mask = pad_img_to_modulo(mask, mod=8)
    mask = (mask > 0) * 1
    image = torch.from_numpy(image).unsqueeze(0).to(device)
    mask = torch.from_numpy(mask).unsqueeze(0).to(device)

    start = time.perf_counter()
    with torch.no_grad():
        inpainted_image = model(image, mask)
    # Copying to the host waits for any asynchronous (CUDA) work, so measure
    # after it. perf_counter is monotonic, unlike time.time.
    cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
    print(f"process time: {(time.perf_counter() - start)*1000}ms")

    cur_res = cur_res[0:origin_height, 0:origin_width, :]
    cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
    cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
    return cur_res
def get_args_parser():
    """Parse the command line (from sys.argv) and return the argparse namespace.

    Options: --port (int, default 8080), --device (str, default "cuda") and the
    --debug flag.
    """
    parser = argparse.ArgumentParser()
    options = (
        ("--port", {"default": 8080, "type": int}),
        ("--device", {"default": "cuda", "type": str}),
        ("--debug", {"action": "store_true"}),
    )
    for flag, settings in options:
        parser.add_argument(flag, **settings)
    return parser.parse_args()
def _binary_mask_from_input(mask, invert_mask):
    """Convert a user mask into a single-channel uint8 mask (255 = remove, 0 = keep).

    An RGBA mask whose alpha carries information (mean alpha < 240) is read from
    its alpha channel, and transparent pixels are marked. Otherwise the RGB
    values, or a single-channel mask, are thresholded at 128 and white pixels are
    marked. ``invert_mask=False`` flips the result in either case.
    """
    if mask.ndim == 3 and mask.shape[2] == 4:
        alpha_channel = mask[:, :, 3]
        # Check if alpha is meaningful (not all 255)
        if alpha_channel.mean() < 240:
            # alpha=0 (transparent) -> 255 (remove); alpha=255 (opaque) -> 0 (keep)
            binary = 255 - alpha_channel
            transparent_count = int((alpha_channel < 128).sum())
            print(f"Using alpha channel: {transparent_count} transparent pixels → white (to remove)")
            if not invert_mask:
                binary = 255 - binary
                print("Applied invert_mask=False: inverted alpha-based mask")
            return binary

    # Alpha absent or mostly opaque: use the colour (or single-channel) values.
    if mask.ndim == 3 and mask.shape[2] >= 3:
        # mask[:, :, :3] is a strided view; hand OpenCV a contiguous copy.
        gray = cv2.cvtColor(np.ascontiguousarray(mask[:, :, :3]), cv2.COLOR_RGB2GRAY)
    else:
        # Single-channel mask given as (H, W) or (H, W, 1).
        gray = mask.reshape(mask.shape[0], mask.shape[1])
    binary = (gray > 128).astype(np.uint8) * 255
    white_count = int((binary > 128).sum())
    print(f"Using RGB channels: {white_count} white pixels (to remove)")
    if not invert_mask:
        binary = 255 - binary  # invert: white becomes black, black becomes white
        print("Applied invert_mask=False: inverted RGB mask (now black=remove)")
    return binary


def process_inpaint(image, mask, invert_mask=True):
    """
    Process inpainting - handles both alpha-based masks and RGB-based masks.
    Preserves original image quality and dimensions.
    Reference: https://huggingface.co/spaces/aryadytm/remove-photo-object

    Args:
        image: RGBA (converted to RGB) or RGB image with 0-255 pixel values.
        mask: RGBA, RGB or single-channel mask; see _binary_mask_from_input.
        invert_mask: True keeps the default convention (transparent / white =
            remove); False inverts the mask.

    Returns:
        An RGB uint8 image with the input's height and width. If the mask marks
        fewer than 10 pixels, the (RGB) input is returned unchanged.
    """
    if image.ndim == 3 and image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    original_image = image
    original_shape = image.shape  # (H, W, C)
    interpolation = cv2.INTER_CUBIC

    # Preserve the original size; only shrink extremely large images (> 3000px)
    # for memory/performance reasons.
    max_dimension = max(image.shape[:2])
    if max_dimension > 3000:
        size_limit = 3000
        print(f"Very large image detected ({max_dimension}px), resizing to {size_limit}px for processing")
    else:
        size_limit = max_dimension  # Keep original size to preserve quality
        print(f"Preserving original image size: {max_dimension}px (no resize)")
    print(f"Origin image shape: {original_shape}")

    needs_resize = size_limit < max_dimension
    if needs_resize:
        image_resized = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
        print(f"Resized image shape: {image_resized.shape}")
    else:
        image_resized = image
        print(f"Image not resized: {image_resized.shape}")

    mask = _binary_mask_from_input(mask, invert_mask)

    # Resize mask to match image dimensions (always force exact match)
    target_h, target_w = image_resized.shape[:2]
    if mask.shape[:2] != (target_h, target_w):
        mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST)

    # Debug: log final mask statistics
    mask_nonzero = int((mask > 128).sum())
    mask_total = mask.shape[0] * mask.shape[1]
    print(f"Final mask before normalization: {mask_nonzero}/{mask_total} pixels marked for removal ({100*mask_nonzero/mask_total:.2f}%)")
    if mask_nonzero < 10:
        print("ERROR: Mask is empty or almost empty! Returning original image.")
        # Same channel order (RGB) and size as the normal return path.
        return original_image.copy()

    print(f"Mask verification: {mask_nonzero} pixels will be removed, shape: {mask.shape}")
    norm_image = norm_img(image_resized)
    norm_mask = norm_img(mask)
    mask_normalized_ones = int((norm_mask > 0.5).sum())
    print(f"After normalization: {mask_normalized_ones} pixels marked for removal (value > 0.5)")

    print("Running LaMa model for inpainting...")
    res_np_img = run(norm_image, norm_mask)
    print(f"Inpainting complete. Output shape: {res_np_img.shape}")

    # Verify the output changed. image_resized holds the raw 0-255 pixels
    # (assumed uint8), so it is compared directly; scaling it by 255 would
    # overflow uint8.
    original_bgr = cv2.cvtColor(image_resized.astype(np.uint8), cv2.COLOR_RGB2BGR)
    diff = np.abs(res_np_img.astype(np.float32) - original_bgr.astype(np.float32))
    # A pixel counts as changed when its absolute change summed over channels exceeds 10.
    diff_pixels = int((diff.sum(axis=2) > 10).sum())
    print(f"Output verification: {diff_pixels} pixels differ from input (should be > 0 if inpainting worked)")

    # Resize back to original dimensions if we resized (use LANCZOS4 for better quality)
    if needs_resize:
        res_np_img = cv2.resize(res_np_img, (original_shape[1], original_shape[0]),
                                interpolation=cv2.INTER_LANCZOS4)
        print(f"Resized output back to original size: {res_np_img.shape}")
    return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)