| | |
| | import sys |
| | import os |
| | from tqdm import tqdm |
| | dir_path = os.path.dirname(os.path.realpath(__file__)) |
| | sys.path.append(os.path.join(dir_path, 'core')) |
| | from pathlib import Path |
| | from data_test_flow import * |
| | from models.network_test_flow import NeuralNRT |
| | from options_test_flow import TestOptions |
| | import torch |
| | import numpy as np |
| |
|
| |
|
| |
|
def save_flow_numpy(filename, flow_input):
    """Write ``flow_input`` to ``filename`` in NumPy's binary ``.npy`` format.

    Args:
        filename: Destination path handed straight to ``np.save``.
        flow_input: Array (or array-like) to serialise.
    """
    np.save(file=filename, arr=flow_input)
| |
|
| |
|
def predict(data):
    """Run the network on one batch and save each sample's masked flow.

    This function reads two module-level globals that the ``__main__``
    section assigns before the first call: ``model`` (the network that is
    invoked) and ``save_path`` (the directory the ``.npy`` files go to).

    Args:
        data: Batch mapping with the keys ``"path_flow"`` (indexable, one
            path per sample) and the tensors ``"src_crop_color"``,
            ``"tar_crop_color"``, ``"src_color"``, ``"tar_color"``,
            ``"src_mask"`` and ``"Crop_param"``.

    Returns:
        None. One ``.npy`` file per sample is written via
        ``save_flow_numpy``.
    """
    with torch.no_grad():
        model.eval()
        path_flow = data["path_flow"]
        src_crop_im = data["src_crop_color"].cuda()
        tar_crop_im = data["tar_crop_color"].cuda()
        src_im = data["src_color"].cuda()
        tar_im = data["tar_color"].cuda()
        # The mask is never passed to the model; it is only multiplied with
        # the prediction on the host. Convert it to NumPy once instead of
        # uploading it to the GPU and downloading a slice per sample.
        src_mask = data["src_mask"].cpu().numpy()
        crop_param = data["Crop_param"].cuda()
        B = src_mask.shape[0]
        flow = model(src_crop_im, tar_crop_im, src_im, tar_im, crop_param)
        for i in range(B):
            flow_tmp = flow[i].cpu().numpy() * src_mask[i]
            # Drops the last six characters of the file name before adding
            # ".npy" -- presumably a six-character extension such as
            # ".oflow"; TODO confirm against the dataset.
            stem = os.path.basename(path_flow[i])[:-6]
            save_flow_numpy(os.path.join(save_path, stem + ".npy"), flow_tmp)
| |
|
| |
|
if __name__ == "__main__":
    # NOTE(review): width/height are assigned but not read anywhere in the
    # visible code.
    width, height = 272, 480

    # The RAFT weight file is used twice: stored on the options object and
    # passed to the model constructor.
    raft_weights = os.path.join(dir_path, 'pretrain_model/raft-small.pth')

    test_opts = TestOptions().parse()
    test_opts.pretrain_model_path = raft_weights

    data_loader = CreateDataLoader(test_opts)
    testloader = data_loader.load_data()

    # `model` and `save_path` are read as globals by predict(); keep names.
    model = NeuralNRT(test_opts, raft_weights)
    checkpoint = torch.load(
        os.path.join(dir_path, 'sgd_NNRT_model_epoch19008_50000.pth'))

    # Restore each sub-network from its own entry in the checkpoint.
    for attr_name, ckpt_key in (("CorresPred", "net_C"),
                                ("ImportanceW", "net_W")):
        getattr(model, attr_name).load_state_dict(checkpoint[ckpt_key])

    model = model.cuda()

    save_path = test_opts.savepath
    Path(save_path).mkdir(parents=True, exist_ok=True)

    # Process every batch, with a progress bar sized to the loader.
    for data in tqdm(testloader, total=len(testloader)):
        predict(data)
| |
|