| import cv2 |
| import time |
| from tqdm import tqdm |
| import numpy as np |
| import onnxruntime as rt |
| import gradio as io |
|
|
|
|
class Matting:
    """Foreground matting with an ONNX model executed through onnxruntime.

    The class loads the model once and offers helpers to matte a single
    frame, a still image, the default camera stream, or a video file. In
    every composited output the background is replaced with white (255).
    """

    def __init__(self, model_path='modnet.onnx', input_size=(512, 512)):
        """Create the inference session and cache the model's I/O names.

        Args:
            model_path: Path of the ONNX model file to load.
            input_size: Size handed to ``resize`` before inference; a
                ``(width, height)`` pair or a single int.
        """
        self.model_path = model_path
        self.sess = rt.InferenceSession(self.model_path)
        # Only the first input and the first output of the graph are used.
        self.input_name = self.sess.get_inputs()[0].name
        self.label_name = self.sess.get_outputs()[0].name
        self.input_size = input_size
        self.txt_font = cv2.FONT_HERSHEY_PLAIN

    def normalize(self, im, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        """Scale ``im`` by 1/255, then subtract ``mean`` and divide by ``std``.

        Args:
            im: Image array; converted to float32 before scaling.
            mean: Per-channel value subtracted after scaling.
            std: Per-channel divisor applied after the subtraction.

        Returns:
            A new float32 array (the division by 255 allocates a fresh
            array, so the caller's image is never modified in place).
        """
        im = im.astype(np.float32, copy=False) / 255.0
        im -= mean
        im /= std
        return im

    def resize(self, im, target_size=608, interp=cv2.INTER_LINEAR):
        """Resize ``im`` with ``cv2.resize``.

        Args:
            im: Image array to resize.
            target_size: ``(width, height)`` as a list/tuple, or a single
                value used for both dimensions.
            interp: OpenCV interpolation flag.

        Returns:
            The resized image.
        """
        if isinstance(target_size, (list, tuple)):
            w, h = target_size[0], target_size[1]
        else:
            w = h = target_size
        return cv2.resize(im, (w, h), interpolation=interp)

    def preprocess(self, image, target_size=(512, 512), interp=cv2.INTER_LINEAR):
        """Normalize and resize ``image``, then lay it out for the model.

        Returns:
            The image with axes reordered by ``[2, 0, 1]`` and a leading
            axis of length 1 added.
        """
        image = self.normalize(image)
        image = self.resize(image, target_size=target_size, interp=interp)
        image = np.transpose(image, [2, 0, 1])
        image = image[None, :, :, :]
        return image

    def predict_frame(self, bgr_image):
        """Run the model on one BGR frame and return its matte.

        Args:
            bgr_image: 3-dimensional image array in BGR channel order.

        Returns:
            Array of shape ``(h, w, 1)`` at the input frame's resolution.
            Presumably values lie in [0, 1] with 1 meaning foreground --
            TODO confirm against the model.
        """
        assert len(bgr_image.shape) == 3, "Please input RGB image."
        raw_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        h, w, _ = raw_image.shape
        image = self.preprocess(raw_image, target_size=self.input_size)

        pred = self.sess.run(
            [self.label_name],
            {self.input_name: image.astype(np.float32)}
        )[0]
        # Take the first item and first channel of the first model output.
        pred = pred[0, 0]
        # Bring the matte back to the original frame resolution.
        matte_np = self.resize(pred, target_size=(w, h), interp=cv2.INTER_NEAREST)
        # Trailing axis so the matte broadcasts over the colour channels.
        matte_np = np.expand_dims(matte_np, axis=-1)
        return matte_np

    @staticmethod
    def _composite(matte_np, bgr_image):
        """Blend ``bgr_image`` over a white background using ``matte_np``."""
        background = np.full(bgr_image.shape, 255.0)
        blended = matte_np * bgr_image + (1 - matte_np) * background
        return blended.astype('uint8')

    def predict_image(self, bgr_image):
        """Return ``bgr_image`` with its background replaced by white.

        Args:
            bgr_image: 3-dimensional image array.

        Returns:
            uint8 image of the same shape as the input.
        """
        assert len(bgr_image.shape) == 3, "Please input RGB image."
        matte_np = self.predict_frame(bgr_image)
        return self._composite(matte_np, bgr_image)

    def predict_camera(self):
        """Matte frames from camera 0 and display them until 'q' is pressed.

        An FPS figure, averaged over a window that restarts every 50
        frames, is drawn on each displayed frame.

        Raises:
            IOError: If the camera cannot be opened.
        """
        cap_video = cv2.VideoCapture(0)
        if not cap_video.isOpened():
            raise IOError("Error opening video stream or file.")
        beg = time.time()
        count = 0
        try:
            while cap_video.isOpened():
                ret, raw_frame = cap_video.read()
                if not ret:
                    break
                count += 1
                matte_np = self.predict_frame(raw_frame)
                matting_frame = self._composite(matte_np, raw_frame)

                end = time.time()
                # Guard against a zero interval on coarse system clocks.
                fps = round(count / max(end - beg, 1e-9), 2)
                if count >= 50:
                    count = 0
                    beg = end

                cv2.putText(matting_frame, "fps: " + str(fps), (20, 20), self.txt_font, 2, (0, 0, 255), 1)

                cv2.imshow('Matting', matting_frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Always free the camera and close the preview window.
            # destroyWindow() requires a window name, so destroyAllWindows()
            # is used to close whatever was opened here.
            cap_video.release()
            cv2.destroyAllWindows()

    def check_video(self, src_path, dst_path):
        """Verify that two videos report the same FPS and frame count.

        Args:
            src_path: Path of the source video.
            dst_path: Path of the processed video.

        Raises:
            AssertionError: If the integer FPS or the frame counts differ.
        """
        cap1 = cv2.VideoCapture(src_path)
        cap2 = cv2.VideoCapture(dst_path)
        try:
            fps1 = int(cap1.get(cv2.CAP_PROP_FPS))
            number_frames1 = cap1.get(cv2.CAP_PROP_FRAME_COUNT)
            fps2 = int(cap2.get(cv2.CAP_PROP_FPS))
            number_frames2 = cap2.get(cv2.CAP_PROP_FRAME_COUNT)
        finally:
            cap1.release()
            cap2.release()
        # Raised explicitly (not via ``assert``) so the check survives ``-O``.
        if fps1 != fps2 or number_frames1 != number_frames2:
            raise AssertionError("fps or number of frames not equal.")

    def predict_video(self, video_path, save_path, threshold=2e-7):
        """Matte every frame of a video and write the result to ``save_path``.

        Output is delayed by one frame: the matte of frame ``t-1`` is only
        written once frame ``t`` has been predicted. If the mattes of
        ``t-2`` and ``t`` agree (mean abs difference below ``threshold``)
        while ``t-1`` differs from ``t-2`` by more than ``threshold``, the
        matte of ``t-1`` is replaced with the one of ``t-2``.

        Args:
            video_path: Path of the input video.
            save_path: Path of the output video (written with fourcc I420).
            threshold: Mean-absolute-difference threshold described above.

        Raises:
            IOError: If the input video cannot be opened.
        """
        time_beg = time.time()

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            cap.release()
            raise IOError("Error opening video stream or file.")
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        number_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        print("source video fps: {}, video resolution: {}, video frames: {}".format(fps, size, number_frames))
        videoWriter = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc('I', '4', '2', '0'), fps, size)

        pre_t2 = None  # matte of frame t-2 (possibly corrected)
        pre_t1 = None  # matte of frame t-1
        prev_frame = None  # image of frame t-1, the one pre_t1 belongs to
        try:
            with tqdm(range(int(number_frames))) as t:
                for _ in t:
                    ret, frame = cap.read()
                    if not ret:
                        # The reported frame count can exceed what is readable.
                        break
                    matte_np = self.predict_frame(frame)
                    if pre_t2 is None:
                        pre_t2 = matte_np
                    elif pre_t1 is None:
                        pre_t1 = matte_np
                        # First output: matte of frame 0 on frame 0.
                        videoWriter.write(self._composite(pre_t2, prev_frame))
                    else:
                        error_interval = np.mean(np.abs(pre_t2 - matte_np))
                        error_neigh = np.mean(np.abs(pre_t1 - pre_t2))
                        if error_interval < threshold < error_neigh:
                            pre_t1 = pre_t2

                        # The delayed matte must be blended with the frame
                        # it was computed from, not with the current one.
                        videoWriter.write(self._composite(pre_t1, prev_frame))
                        pre_t2 = pre_t1
                        pre_t1 = matte_np
                    prev_frame = frame

            # Flush the matte still pending for the last frame read. For a
            # single-frame input only pre_t2 has been set.
            last_matte = pre_t1 if pre_t1 is not None else pre_t2
            if last_matte is not None:
                videoWriter.write(self._composite(last_matte, prev_frame))
        finally:
            cap.release()
            # Releasing the writer finalises the output file.
            videoWriter.release()

        elapsed = time.time() - time_beg
        print("video matting over, time consume: {}, fps: {}".format(elapsed,
                                                                     number_frames / elapsed))
|
|
|
|
if __name__ == '__main__':
    # Script entry point: expose single-image matting through a Gradio
    # image-in / image-out interface.
    # NOTE(review): predict_image treats its input as BGR; confirm the
    # channel order Gradio's "image" input actually delivers.
    matting = Matting(model_path='modnet.onnx', input_size=(512, 512))
    demo = io.Interface(
        fn=matting.predict_image,
        inputs="image",
        outputs="image",
    )
    demo.launch()
|
|