Spaces:
Runtime error
Runtime error
| from __future__ import print_function | |
| import sys | |
| import cv2 | |
| import argparse | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.backends.cudnn as cudnn | |
| import torch.optim as optim | |
| import torch.nn.functional as F | |
# Turn off cuDNN benchmark mode (autotuned selection of convolution algorithms).
torch.backends.cudnn.benchmark = False
class Expansion:
    """Inference wrapper around the VCN optical-expansion network.

    The constructor builds ``expansion.models.VCN_exp.VCN``, wraps it in
    ``nn.DataParallel`` on CUDA device 0, optionally loads a checkpoint, and
    rebuilds the flow-regression and warping modules for a fixed
    ``_INFER_SIZE`` x ``_INFER_SIZE`` resolution. ``run`` takes an image pair
    and returns the model's flow and log-expansion outputs.
    """

    # Spatial sizes are rounded up to a multiple of this value.
    _ALIGN = 64
    # Height and width that the regression/warp modules are rebuilt for.
    # NOTE(review): `run` presumably expects inputs of this size; confirm with callers.
    _INFER_SIZE = 256

    def __init__(self, loadmodel='pretrained_models/optical_expansion/robust.pth',
                 testres=1, maxdisp=256, fac=1):
        """Build the model, optionally load weights, and prepare it for inference.

        Args:
            loadmodel: Path to a checkpoint with ``state_dict``, ``mean_L`` and
                ``mean_R`` entries, or ``None`` to skip loading ("dry run").
                A path containing ``'robust'`` enables ``exp_unc`` on the model.
            testres: Scale applied to the 1280x384 base size used to construct
                the initial network.
            maxdisp: The first entry of ``md`` is ``int(4 * maxdisp / 256)``.
            fac: Forwarded unchanged to ``VCN``.
        """
        maxw = self._ceil_to_multiple(int(testres * 1280), self._ALIGN)
        maxh = self._ceil_to_multiple(int(testres * 384), self._ALIGN)
        mean_L = [[0.33, 0.33, 0.33]]
        mean_R = [[0.33, 0.33, 0.33]]

        # Construct the VCN-expansion model.
        from expansion.models.VCN_exp import VCN
        # Expansion uncertainty exists only in the new ("robust") model. Check
        # for None first: `'robust' in None` raises TypeError, which previously
        # made the dry-run branch below unreachable.
        exp_unc = loadmodel is not None and 'robust' in str(loadmodel)
        model = VCN([1, maxw, maxh], md=[int(4 * (maxdisp / 256)), 4, 4, 4, 4],
                    fac=fac, exp_unc=exp_unc)
        model = nn.DataParallel(model, device_ids=[0])
        model.cuda()

        if loadmodel is not None:
            # NOTE(review): torch.load unpickles the file; load trusted checkpoints only.
            pretrained_dict = torch.load(loadmodel)
            mean_L = pretrained_dict['mean_L']
            mean_R = pretrained_dict['mean_R']
            model.load_state_dict(pretrained_dict['state_dict'], strict=False)
        else:
            print('dry run')
        model.eval()

        # Swap in size-dependent modules built for the inference resolution.
        size = self._ceil_to_multiple(self._INFER_SIZE, self._ALIGN)
        self._resize_modules(model, size, size)

        self.max_h = size
        self.max_w = size
        self.model = model
        self.mean_L = self._mean_to_tensor(mean_L)
        self.mean_R = self._mean_to_tensor(mean_R)

    @staticmethod
    def _ceil_to_multiple(value, base):
        """Round the non-negative int *value* up to the nearest multiple of *base*."""
        return -(-value // base) * base

    @staticmethod
    def _resize_modules(model, max_w, max_h):
        """Replace ``model.module``'s reg/warp modules with ones sized for max_w x max_h."""
        from expansion.models.VCN_exp import WarpModule, flow_reg
        net = model.module
        for i in range(len(net.reg_modules)):
            scale = 2 ** (6 - i)
            # Hyper-parameters are copied from the named flow_reg<scale> attribute.
            old = getattr(net, 'flow_reg%d' % scale)
            net.reg_modules[i] = flow_reg([1, max_w // scale, max_h // scale],
                                          ent=old.ent, maxdisp=old.md,
                                          fac=old.fac).cuda()
        for i in range(len(net.warp_modules)):
            scale = 2 ** (6 - i)
            net.warp_modules[i] = WarpModule([1, max_w // scale, max_h // scale]).cuda()

    @staticmethod
    def _mean_to_tensor(mean):
        """Average *mean* over axis 0 and return it as a (1, C, 1, 1) CUDA tensor."""
        arr = np.asarray(mean).astype(np.float32).mean(0)
        return torch.from_numpy(arr[np.newaxis, :, np.newaxis, np.newaxis]).cuda()

    @staticmethod
    def _preprocess(img, mean):
        """Clamp *img* to [-1, 1], map it to [0, 1], and subtract *mean*.

        Uses out-of-place ``torch.clamp`` so the caller's tensor is not mutated.
        """
        return (torch.clamp(img, -1.0, 1.0) + 1.0) * 0.5 - mean

    def run(self, imgL_o, imgR_o):
        """Predict flow and log-expansion between two images.

        Args:
            imgL_o, imgR_o: Image tensors with values nominally in [-1, 1].
                NOTE(review): presumably (1, 3, max_h, max_w) CUDA tensors;
                confirm with callers. They are no longer modified in place.

        Returns:
            ``(flow, logexp)``: the first and fourth outputs of the model.
        """
        imgL = self._preprocess(imgL_o, self.mean_L)
        # Both images now use (x + 1) * 0.5. The right image previously used
        # `* 1.`, which did not map it to the same [0, 1] range as the left image.
        imgR = self._preprocess(imgR_o, self.mean_R)
        with torch.no_grad():
            imgLR = torch.cat([imgL, imgR], 0)
            self.model.eval()
            torch.cuda.synchronize()
            rts = self.model(imgLR)
            torch.cuda.synchronize()
            flow, _occ, _logmid, logexp = rts
        torch.cuda.empty_cache()
        return flow, logexp