Spaces:
Build error
Build error
File size: 4,148 Bytes
22755c5 a15e93c 22755c5 0d73319 22755c5 a15e93c 0d73319 22755c5 a15e93c db77e37 8e4a5f8 db77e37 8e4a5f8 22755c5 c2adab2 22755c5 0d73319 22755c5 b15d1ff 22755c5 b15d1ff 22755c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import functools
import hmac

import cv2
import gradio as gr
import numpy as np
import onnxruntime
import requests
import streamlit as st
from huggingface_hub import hf_hub_download
from PIL import Image
def get_scale_factor(im_h, im_w, ref_size=512, multiple=32):
    """Return the (x, y) scale factors used to resize an image before inference.

    The image is rescaled, keeping its aspect ratio, so that its shorter side
    equals ``ref_size`` in two cases: it is smaller than ``ref_size`` on both
    sides, or larger than ``ref_size`` on both sides. Otherwise its size is
    kept. Each resulting side is then rounded down to a multiple of
    ``multiple``, but never below ``multiple``. The rounding is presumably
    required by the network's downsampling stride (TODO confirm).

    Args:
        im_h: Image height in pixels.
        im_w: Image width in pixels.
        ref_size: Target length of the shorter side when rescaling.
        multiple: Alignment each resized side is rounded down to.

    Returns:
        Tuple ``(x_scale_factor, y_scale_factor)``, suitable for the ``fx`` /
        ``fy`` arguments of ``cv2.resize``.

    Raises:
        ValueError: If a dimension or ``multiple`` is not positive.
    """
    if im_h <= 0 or im_w <= 0:
        raise ValueError(f"image dimensions must be positive, got {im_h}x{im_w}")
    if multiple <= 0:
        raise ValueError(f"multiple must be positive, got {multiple}")

    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        if im_w >= im_h:
            im_rh = ref_size
            im_rw = int(im_w / im_h * ref_size)
        else:
            im_rw = ref_size
            im_rh = int(im_h / im_w * ref_size)
    else:
        im_rh, im_rw = im_h, im_w

    # Round down to the alignment, but keep at least one unit. Otherwise a very
    # thin image would get a zero scale factor, which cv2.resize rejects.
    im_rw = max(multiple, im_rw - im_rw % multiple)
    im_rh = max(multiple, im_rh - im_rh % multiple)

    return im_rw / im_w, im_rh / im_h
# Download the MODNet ONNX weights from the Hugging Face dataset repository
# and keep the path to the local copy for building the inference session.
MODEL_PATH = hf_hub_download(
    repo_id='huedaya/background-remover-files',
    filename='modnet.onnx',
    repo_type='dataset',
)
@functools.lru_cache(maxsize=1)
def _get_session():
    """Build the ONNX Runtime session for MODEL_PATH once and reuse it across calls."""
    return onnxruntime.InferenceSession(MODEL_PATH, None)


def _is_authorized(api):
    """Return True when *api* matches the ``Api-Key`` secret.

    ``hmac.compare_digest`` avoids content-based short-circuiting, which
    makes timing attacks on the key harder than with ``!=``.
    """
    if not isinstance(api, str):
        return False
    # NOTE(review): st.secrets reads Streamlit's secrets store. Confirm it is
    # populated when this app runs under Gradio.
    expected = st.secrets["Api-Key"]
    return hmac.compare_digest(str(expected).encode("utf-8"), api.encode("utf-8"))


def _to_model_input(rgb):
    """Convert an HxWx3 uint8 RGB image into the 1x3xH'xW' float32 tensor fed to the model."""
    im = (rgb - 127.5) / 127.5  # scale pixel values to [-1, 1]
    im_h, im_w = im.shape[:2]
    x, y = get_scale_factor(im_h, im_w)
    im = cv2.resize(im, None, fx=x, fy=y, interpolation=cv2.INTER_AREA)
    # HWC -> CHW, then add a leading batch axis.
    return np.expand_dims(im.transpose(2, 0, 1), axis=0).astype('float32')


def _predict_matte(tensor, im_h, im_w):
    """Run the model on *tensor* and return its matte as uint8, resized to (im_h, im_w)."""
    session = _get_session()
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    result = session.run([output_name], {input_name: tensor})
    # Assumes the model outputs alpha in [0, 1]. Clip so stray values
    # saturate instead of wrapping around on the uint8 cast.
    matte = np.clip(np.squeeze(result[0]) * 255, 0, 255).astype('uint8')
    return cv2.resize(matte, dsize=(im_w, im_h), interpolation=cv2.INTER_AREA)


def _composite(rgb, matte, threshold):
    """Return an RGBA uint8 image keeping only pixels whose matte value exceeds *threshold*.

    Kept pixels are fully opaque; all others become transparent black.
    """
    alpha = np.full(rgb.shape[:2], 255, dtype=np.uint8)
    rgba = np.dstack([rgb, alpha])
    keep = (matte > threshold)[:, :, None]
    return np.where(keep, rgba, 0).astype(np.uint8)


def main(image_path, threshold, api):
    """Remove the background from the image stored at *image_path*.

    Args:
        image_path: Path to the uploaded image file.
        threshold: Matte cutoff (matte values are 0-255). Pixels whose
            predicted matte value is greater than this are kept.
        api: API key entered by the user.

    Returns:
        A ``PIL.Image`` in RGBA mode with the background made transparent,
        or the input image unchanged if *api* does not match the secret.

    Raises:
        ValueError: If OpenCV cannot decode the image file.
    """
    if not _is_authorized(api):
        # Wrong key: echo the input back without running the model.
        with Image.open(image_path) as image:
            return Image.fromarray(np.asarray(image))

    # Decode once and use the same pixels for inference and compositing, so
    # the matte always lines up with the image. (PIL, unlike cv2.imread,
    # neither applies EXIF orientation nor expands palette images.)
    bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if bgr is None:
        raise ValueError(f"could not read image: {image_path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    im_h, im_w = rgb.shape[:2]

    # The matte stays in memory; no temporary file, so concurrent requests
    # cannot overwrite each other's results.
    matte = _predict_matte(_to_model_input(rgb), im_h, im_w)
    return Image.fromarray(_composite(rgb, matte, threshold))
# Text shown by the Gradio UI.
title = "MODNet Background Remover"
description = "Gradio demo for MODNet, a model that can remove the background from a given image. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<div style='text-align: center;'> <a href='https://github.com/ZHKKKe/MODNet' target='_blank'>Github Repo</a> | <a href='https://arxiv.org/abs/2011.11961' target='_blank'>MODNet: Real-Time Trimap-Free Portrait Matting via Objective Decomposition</a> </div>"

# Source of the example image offered in the UI.
url = "https://huggingface.co/datasets/huedaya/background-remover-files/resolve/main/test1.jpg"


def _download_example(src, dest, timeout=30):
    """Download *src* and write its bytes to *dest*; raise on HTTP errors."""
    response = requests.get(src, timeout=timeout)
    response.raise_for_status()
    with open(dest, 'wb') as f:
        f.write(response.content)


# The example list below refers to this file by its local name.
_download_example(url, 'test1.jpg')

# Top-level components replace the deprecated gr.inputs.* API (removed in
# Gradio 4); the slider's initial value is passed as `value`.
interface = gr.Interface(
    fn=main,
    inputs=[
        gr.Image(type='filepath'),
        gr.Slider(minimum=0, maximum=250, value=100, step=5, label='Mask Cutoff Threshold'),
        # Mask the key so it is not displayed in clear text.
        gr.Textbox(label='API-Key', type='password'),
    ],
    outputs='image',
    examples=[['test1.jpg', 120, 'test']],
    title=title,
    description=description,
    article=article,
)

if __name__ == '__main__':
    interface.launch(debug=True)
|