Spaces:
Build error
Build error
| # coding=utf-8 | |
| import os | |
| import sys | |
| sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..'))) | |
| import argparse | |
| import os | |
| import cv2 | |
| import glob | |
| import copy | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| import scipy.ndimage | |
| import torchvision.transforms.functional as F | |
| import torch.nn.functional as F2 | |
| from RAFT import utils | |
| from RAFT import RAFT | |
| import utils.region_fill as rf | |
| from torchvision.transforms import ToTensor | |
| import time | |
| def to_tensor(img): | |
| img = Image.fromarray(img) | |
| img_t = F.to_tensor(img).float() | |
| return img_t | |
| def gradient_mask(mask): # 产生梯度的mask | |
| gradient_mask = np.logical_or.reduce((mask, | |
| np.concatenate((mask[1:, :], np.zeros((1, mask.shape[1]), dtype=np.bool)), | |
| axis=0), | |
| np.concatenate((mask[:, 1:], np.zeros((mask.shape[0], 1), dtype=np.bool)), | |
| axis=1))) | |
| return gradient_mask | |
| def create_dir(dir): | |
| """Creates a directory if not exist. | |
| """ | |
| if not os.path.exists(dir): | |
| os.makedirs(dir) | |
| def initialize_RAFT(args): | |
| """Initializes the RAFT model. | |
| """ | |
| model = torch.nn.DataParallel(RAFT(args)) | |
| model.load_state_dict(torch.load(args.model)) | |
| model = model.module | |
| model.to('cuda') | |
| model.eval() | |
| return model | |
| def calculate_flow(args, model, vid, video, mode): | |
| """Calculates optical flow. | |
| """ | |
| if mode not in ['forward', 'backward']: | |
| raise NotImplementedError | |
| nFrame, _, imgH, imgW = video.shape | |
| Flow = np.empty(((imgH, imgW, 2, 0)), dtype=np.float32) | |
| create_dir(os.path.join(args.outroot, vid, mode + '_flo')) | |
| # create_dir(os.path.join(args.outroot, vid, 'flow', mode + '_png')) | |
| with torch.no_grad(): | |
| for i in range(video.shape[0] - 1): | |
| print("Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1), '\r', end='') | |
| if mode == 'forward': | |
| # Flow i -> i + 1 | |
| image1 = video[i, None] | |
| image2 = video[i + 1, None] | |
| elif mode == 'backward': | |
| # Flow i + 1 -> i | |
| image1 = video[i + 1, None] | |
| image2 = video[i, None] | |
| else: | |
| raise NotImplementedError | |
| _, flow = model(image1, image2, iters=20, test_mode=True) | |
| flow = flow[0].permute(1, 2, 0).cpu().numpy() | |
| # Flow = np.concatenate((Flow, flow[..., None]), axis=-1) | |
| # Flow visualization. | |
| # flow_img = utils.flow_viz.flow_to_image(flow) | |
| # flow_img = Image.fromarray(flow_img) | |
| # Saves the flow and flow_img. | |
| # flow_img.save(os.path.join(args.outroot, vid, 'flow', mode + '_png', '%05d.png'%i)) | |
| utils.frame_utils.writeFlow(os.path.join(args.outroot, vid, mode + '_flo', '%05d.flo' % i), flow) | |
| def main(args): | |
| # Flow model. | |
| RAFT_model = initialize_RAFT(args) | |
| videos = os.listdir(args.path) | |
| videoLen = len(videos) | |
| try: | |
| exceptList = os.listdir(args.expdir) | |
| except: | |
| exceptList = [] | |
| v = 0 | |
| for vid in videos: | |
| v += 1 | |
| print('[{}]/[{}] Video {} is being processed'.format(v, len(videos), vid)) | |
| if vid in exceptList: | |
| print('Video: {} skipped'.format(vid)) | |
| continue | |
| # Loads frames. | |
| filename_list = glob.glob(os.path.join(args.path, vid, '*.png')) + \ | |
| glob.glob(os.path.join(args.path, vid, '*.jpg')) | |
| # Obtains imgH, imgW and nFrame. | |
| imgH, imgW = np.array(Image.open(filename_list[0])).shape[:2] | |
| nFrame = len(filename_list) | |
| print('images are loaded') | |
| # Loads video. | |
| video = [] | |
| for filename in sorted(filename_list): | |
| print(filename) | |
| img = np.array(Image.open(filename)) | |
| if args.width != 0 and args.height != 0: | |
| img = cv2.resize(img, (args.width, args.height), cv2.INTER_LINEAR) | |
| video.append(torch.from_numpy(img.astype(np.uint8)).permute(2, 0, 1).float()) | |
| video = torch.stack(video, dim=0) | |
| video = video.to('cuda') | |
| # Calcutes the corrupted flow. | |
| start = time.time() | |
| calculate_flow(args, RAFT_model, vid, video, 'forward') | |
| calculate_flow(args, RAFT_model, vid, video, 'backward') | |
| end = time.time() | |
| sumTime = end - start | |
| print('{}/{}, video {} is finished. {} frames takes {}s, {}s/frame.'.format(v, videoLen, vid, nFrame, sumTime, | |
| sumTime / (2 * nFrame))) | |
| if __name__ == '__main__': | |
| parser = argparse.ArgumentParser() | |
| # flow basic setting | |
| parser.add_argument('--path', required=True, type=str) | |
| parser.add_argument('--expdir', type=str) | |
| parser.add_argument('--outroot', required=True, type=str) | |
| parser.add_argument('--width', type=int, default=432) | |
| parser.add_argument('--height', type=int, default=256) | |
| # RAFT | |
| parser.add_argument('--model', default='../weight/raft-things.pth', help="restore checkpoint") | |
| parser.add_argument('--small', action='store_true', help='use small model') | |
| parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') | |
| parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') | |
| args = parser.parse_args() | |
| main(args) | |