modnet / app.py
vietvd's picture
Duplicate from qingning/gradio-MODNet
2e3be05
Raw
History Blame Contribute Delete
6.35 kB
import cv2
import time
from tqdm import tqdm
import numpy as np
import onnxruntime as rt
import gradio as io
class Matting:
    """Image/video matting driven by an ONNX model (``modnet.onnx`` by default).

    The class wraps an ONNX Runtime inference session and offers helpers to
    matte a single frame, a still image, a webcam stream, or a video file.
    Every composited result places the predicted foreground on a white
    background.
    """

    def __init__(self, model_path='modnet.onnx', input_size=(512, 512)):
        """Create the inference session and cache the model's I/O names.

        Args:
            model_path: Path of the ONNX model file to load.
            input_size: ``(width, height)`` each frame is resized to before
                it is fed to the model.
        """
        self.model_path = model_path
        self.sess = rt.InferenceSession(self.model_path)
        # Only the first input and the first output of the graph are used.
        self.input_name = self.sess.get_inputs()[0].name
        self.label_name = self.sess.get_outputs()[0].name
        self.input_size = input_size
        self.txt_font = cv2.FONT_HERSHEY_PLAIN

    def normalize(self, im, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        """Return ``(im / 255 - mean) / std`` as a float32 array.

        Args:
            im: Image array; the last axis must broadcast against ``mean``
                and ``std``.
            mean: Per-channel value subtracted after dividing by 255.
            std: Per-channel divisor applied after the subtraction.

        Returns:
            A new float32 array; the caller's array is not modified.
        """
        # The division allocates a fresh array, so the in-place operations
        # below never write into the caller's image.
        im = im.astype(np.float32, copy=False) / 255.0
        im -= mean
        im /= std
        return im

    def resize(self, im, target_size=608, interp=cv2.INTER_LINEAR):
        """Resize ``im`` with ``cv2.resize``.

        Args:
            im: Image array accepted by ``cv2.resize``.
            target_size: Either a ``(width, height)`` list/tuple or a single
                integer used for both dimensions.
            interp: OpenCV interpolation flag.

        Returns:
            The resized image.
        """
        if isinstance(target_size, (list, tuple)):
            w, h = target_size[0], target_size[1]
        else:
            w = h = target_size
        return cv2.resize(im, (w, h), interpolation=interp)

    def preprocess(self, image, target_size=(512, 512), interp=cv2.INTER_LINEAR):
        """Turn an HWC image into a normalised ``(1, C, H, W)`` batch.

        Args:
            image: HWC image array.
            target_size: ``(width, height)`` passed on to :meth:`resize`.
            interp: OpenCV interpolation flag used for the resize.

        Returns:
            A float32 array with a leading batch axis of length one.
        """
        image = self.normalize(image)
        image = self.resize(image, target_size=target_size, interp=interp)
        image = np.transpose(image, [2, 0, 1])  # HWC -> CHW
        return image[None, :, :, :]  # prepend the batch axis

    @staticmethod
    def _composite_on_white(matte, image):
        """Blend ``image`` over a white background using ``matte`` as alpha.

        Returns the blended picture converted to ``uint8``.
        """
        blended = matte * image + (1 - matte) * np.full(image.shape, 255.0)
        return blended.astype('uint8')

    def predict_frame(self, bgr_image):
        """Predict the alpha matte of one BGR frame.

        Args:
            bgr_image: Three-dimensional (H, W, C) image in OpenCV's BGR
                channel order.

        Returns:
            The matte resized back to the frame's resolution, with shape
            ``(H, W, 1)`` so it broadcasts against the colour channels.

        Raises:
            AssertionError: If ``bgr_image`` is not three-dimensional.
        """
        assert len(bgr_image.shape) == 3, "Please input a 3-channel BGR image."
        raw_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        h, w = raw_image.shape[:2]
        image = self.preprocess(raw_image, target_size=self.input_size)
        pred = self.sess.run(
            [self.label_name],
            {self.input_name: image.astype(np.float32)}
        )[0]
        # Drop the batch and channel axes of the prediction.
        # NOTE(review): assumes the model output is (1, 1, H, W) - confirm.
        pred = pred[0, 0]
        matte_np = self.resize(pred, target_size=(w, h), interp=cv2.INTER_NEAREST)
        return np.expand_dims(matte_np, axis=-1)

    def predict_image(self, bgr_image):
        """Matte a still image and composite it onto a white background.

        Args:
            bgr_image: Three-dimensional (H, W, C) image in BGR channel order.

        Returns:
            A ``uint8`` image of the same shape as the input.

        Raises:
            AssertionError: If ``bgr_image`` is not three-dimensional.
        """
        assert len(bgr_image.shape) == 3, "Please input a 3-channel BGR image."
        matte_np = self.predict_frame(bgr_image)
        return self._composite_on_white(matte_np, bgr_image)

    def predict_camera(self):
        """Matte the default camera stream live until ``q`` is pressed.

        Each composited frame is shown in a window titled ``Matting`` with an
        FPS overlay. The FPS counter is restarted every 50 frames.

        Raises:
            IOError: If camera 0 cannot be opened.
        """
        cap_video = cv2.VideoCapture(0)
        if not cap_video.isOpened():
            raise IOError("Error opening video stream or file.")

        beg = time.time()
        count = 0
        try:
            while cap_video.isOpened():
                ret, raw_frame = cap_video.read()
                if not ret:
                    break

                count += 1
                matte_np = self.predict_frame(raw_frame)
                matting_frame = self._composite_on_white(matte_np, raw_frame)

                end = time.time()
                elapsed = end - beg
                # Guard against a zero interval on coarse clocks.
                fps = round(count / elapsed, 2) if elapsed > 0 else 0.0
                if count >= 50:
                    count = 0
                    beg = end

                cv2.putText(matting_frame, "fps: " + str(fps), (20, 20), self.txt_font, 2, (0, 0, 255), 1)
                cv2.imshow('Matting', matting_frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Always give the camera back, even when inference fails.
            cap_video.release()
            # cv2.destroyWindow() requires a window name; destroyAllWindows()
            # is also safe when no window was ever created.
            cv2.destroyAllWindows()

    def check_video(self, src_path, dst_path):
        """Assert that two videos report the same FPS and frame count.

        Args:
            src_path: Path of the source video.
            dst_path: Path of the processed video.

        Raises:
            AssertionError: If the integer FPS or the frame counts differ.
        """
        cap1 = cv2.VideoCapture(src_path)
        cap2 = cv2.VideoCapture(dst_path)
        try:
            fps1 = int(cap1.get(cv2.CAP_PROP_FPS))
            number_frames1 = cap1.get(cv2.CAP_PROP_FRAME_COUNT)
            fps2 = int(cap2.get(cv2.CAP_PROP_FPS))
            number_frames2 = cap2.get(cv2.CAP_PROP_FRAME_COUNT)
        finally:
            # Release the handles before the assertion can raise.
            cap1.release()
            cap2.release()
        assert fps1 == fps2 and number_frames1 == number_frames2, "fps or number of frames not equal."

    def predict_video(self, video_path, save_path, threshold=2e-7):
        """Matte a video file and write the result to ``save_path``.

        Output is delayed by one frame so that a matte which disagrees with
        its predecessor, while predecessor and successor agree, can be
        replaced by the predecessor's matte before the frame is written
        (the smoothing strategy the original comments call "odf").

        Every written frame is composited with the matte that belongs to that
        same frame, and frames are read until the stream ends, so the output
        holds one frame per decoded input frame.

        Args:
            video_path: Path of the input video.
            save_path: Path of the output video (written with the I420 fourcc).
            threshold: Mean-absolute-difference threshold used to decide
                whether two mattes agree.

        Raises:
            IOError: If the input video cannot be opened.
        """
        time_beg = time.time()
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise IOError("Error opening video stream or file.")

        video_writer = None
        frames_written = 0
        try:
            fps = int(cap.get(cv2.CAP_PROP_FPS))
            size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                    int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            number_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
            print("source video fps: {}, video resolution: {}, video frames: {}".format(fps, size, number_frames))
            video_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc('I', '4', '2', '0'), fps, size)

            settled_matte = None  # matte of frame t-2 (already written)
            pending_frame = None  # frame t-1, waiting for its successor
            pending_matte = None  # matte of frame t-1

            # The container's frame count is only an estimate (and may be
            # missing), so it drives the progress bar but not the loop.
            total = int(number_frames) if number_frames > 0 else None
            with tqdm(total=total) as progress:
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    matte_np = self.predict_frame(frame)

                    if pending_matte is not None:
                        if settled_matte is not None:
                            # Frames t-2 and t agree while t-1 deviates from
                            # t-2: treat t-1 as flicker and reuse t-2's matte.
                            error_interval = np.mean(np.abs(settled_matte - matte_np))
                            error_neigh = np.mean(np.abs(pending_matte - settled_matte))
                            if error_interval < threshold < error_neigh:
                                pending_matte = settled_matte
                        video_writer.write(self._composite_on_white(pending_matte, pending_frame))
                        frames_written += 1
                        settled_matte = pending_matte

                    pending_frame, pending_matte = frame, matte_np
                    progress.update(1)

            # Flush the last frame, which has no successor to compare with.
            if pending_matte is not None:
                video_writer.write(self._composite_on_white(pending_matte, pending_frame))
                frames_written += 1
        finally:
            cap.release()
            # Without release() the output container may never be finalised.
            if video_writer is not None:
                video_writer.release()

        elapsed = time.time() - time_beg
        print("video matting over, time consume: {}, fps: {}".format(
            elapsed, frames_written / elapsed if elapsed > 0 else 0.0))
if __name__ == '__main__':
    # Launch a Gradio demo that mattes an uploaded image onto a white background.
    model = Matting(model_path='modnet.onnx', input_size=(512, 512))

    def matting_rgb(rgb_image):
        """Adapt Gradio's RGB arrays to the BGR-based ``predict_image``.

        Gradio's "image" component hands over RGB numpy arrays, whereas
        ``predict_image`` expects OpenCV's BGR order and swaps the channels
        before inference. Without this conversion the model is fed
        channel-reversed input.
        """
        bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
        matting_bgr = model.predict_image(bgr_image)
        # Convert back so the browser shows the original colours.
        return cv2.cvtColor(matting_bgr, cv2.COLOR_BGR2RGB)

    interface = io.Interface(fn=matting_rgb, inputs="image", outputs="image")
    interface.launch()