| import torch |
| from model import MattingNetwork |
| from torch.utils.data import DataLoader |
| from torch.utils.data.dataset import Dataset |
| import glob |
| import os |
| import cv2 |
| import pdb |
| import argparse |
|
|
class ItwDataset(Dataset):
    """Folder of "in-the-wild" images fed one at a time to the matting model.

    Every ``*.png`` and ``*.jpg`` file directly inside ``input_pth`` is
    collected, the paths are sorted lexicographically and every ``step``-th
    one is kept. Each item is returned as a batch of one RGB image.

    Args:
        input_pth: Directory holding the input images. It is escaped before
            globbing, so directory names containing ``[``, ``*`` or ``?``
            are matched literally.
        step: Stride used to subsample the sorted file list (1 keeps all).
        rotate: Rotation applied after loading: ``'+90'`` (clockwise),
            ``'-90'`` (counter-clockwise), ``'180'``, or ``''``/``None`` for
            no rotation.

    Raises:
        ValueError: If ``rotate`` is not one of the values above.
    """

    # Accepted rotation specifiers; anything else used to be silently
    # ignored, which hid typos such as '90'.
    _VALID_ROTATIONS = (None, '', '+90', '-90', '180')

    def __init__(self, input_pth, step, rotate):
        if rotate not in self._VALID_ROTATIONS:
            raise ValueError(
                f"rotate must be one of '', '+90', '-90', '180'; got {rotate!r}")

        # Escape the directory so glob metacharacters in it are literal;
        # os.fspath keeps pathlib.Path inputs working.
        pattern_dir = glob.escape(os.fspath(input_pth))
        self.input_pth_list = sorted(
            glob.glob(os.path.join(pattern_dir, '*.png'))
            + glob.glob(os.path.join(pattern_dir, '*.jpg')))
        # Keep every `step`-th frame of the sorted sequence.
        self.input_pth_list = self.input_pth_list[::step]
        self.rotate = rotate

    def __len__(self):
        """Return the number of images kept after subsampling."""
        return len(self.input_pth_list)

    def __getitem__(self, index):
        """Load, optionally rotate, and convert one image.

        Returns:
            dict with
              ``'img'``: float32 tensor of shape (1, 3, H, W), RGB order,
              values in [0, 1];
              ``'file_name'``: the image's base name without its extension.

        Raises:
            OSError: If OpenCV cannot decode the file.
        """
        render_path = self.input_pth_list[index]

        img = cv2.imread(render_path)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise OSError(f'Could not read image: {render_path}')

        if self.rotate == '+90':
            img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        elif self.rotate == '-90':
            img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif self.rotate == '180':
            img = cv2.rotate(img, cv2.ROTATE_180)

        # OpenCV decodes to BGR; the RVM matting network expects RGB frames.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # HWC uint8 -> 1xCxHxW float32 in [0, 1].
        img = torch.from_numpy(img).permute(2, 0, 1).float().div(255.0).unsqueeze(0)

        return {
            'img': img,
            'file_name': os.path.splitext(os.path.basename(render_path))[0],
        }
|
|
def resolve_device(spec):
    """Translate a ``--device`` value into a :class:`torch.device`.

    A bare integer (``'0'``, ``1``) selects that CUDA ordinal, preserving the
    script's original ``cuda:{device}`` convention. Anything else (``'cpu'``,
    ``'cuda:1'``) is passed to :class:`torch.device` unchanged. Previously the
    default ``'cpu'`` became the invalid ``'cuda:cpu'``.
    """
    spec = str(spec).strip()
    if spec.isdigit():
        return torch.device(f'cuda:{spec}')
    return torch.device(spec)


def main():
    """Predict a binary alpha mask for every image in ``--input_pth``.

    Frames are processed in sorted order. The model's recurrent state is
    threaded from one frame to the next. Each mask is saved as
    ``<output_pth>/<file_name>.png`` with 3 identical channels and values
    0/255. Outputs that already exist are skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_pth', type=str, required=True,
                        help='directory with *.png / *.jpg frames')
    parser.add_argument('--output_pth', type=str, required=True,
                        help='directory that receives the mask PNGs')
    parser.add_argument('--device', type=str, default='cpu',
                        help="CUDA ordinal such as '0', or a torch device string such as 'cpu'")
    parser.add_argument('--step', type=int, default=1,
                        help='keep every N-th frame')
    parser.add_argument('--rotate', type=str, default='',
                        choices=['', '+90', '-90', '180'],
                        help='rotation applied to each frame before inference')
    parser.add_argument('--downsample_ratio', type=float, default=0.4)
    parser.add_argument('--checkpoint', type=str,
                        default='./checkpoint/rvm_mobilenetv3.pth')
    args = parser.parse_args()

    device = resolve_device(args.device)
    model = MattingNetwork(variant='mobilenetv3').eval().to(device)
    # map_location lets a checkpoint saved from GPU load on a CPU-only host.
    model.load_state_dict(torch.load(args.checkpoint, map_location=device))

    frame_dataset = ItwDataset(args.input_pth, args.step, args.rotate)
    os.makedirs(args.output_pth, exist_ok=True)

    # Recurrent state threaded through consecutive model calls; starts empty.
    rec = [None] * 4

    with torch.no_grad():
        for data in frame_dataset:
            save_img_pth = os.path.join(args.output_pth, data['file_name'] + '.png')
            if os.path.exists(save_img_pth):
                # NOTE(review): skipped frames do not advance `rec`, so a
                # resumed run continues from stale recurrent state.
                print(save_img_pth + ' exists!')
                continue

            _fgr, pha, *rec = model(data['img'].to(device), *rec, args.downsample_ratio)

            # Round alpha, scale to 0/255, saturate like OpenCV's implicit
            # float->8-bit conversion did, and replicate to 3 channels.
            mask_infer = (torch.round(pha) * 255).clamp(0, 255).to(torch.uint8)
            mask_infer = mask_infer.repeat(1, 3, 1, 1)
            mask_infer = mask_infer.squeeze(0).permute(1, 2, 0).contiguous().cpu().numpy()

            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(save_img_pth, mask_infer):
                raise OSError(f'Failed to write mask: {save_img_pth}')
            print(data['file_name'])


if __name__ == '__main__':
    main()