drozdgk's picture
chore: vendor third_party (remove submodules, ignore artifacts)
352cafd
import torch
import torch.nn.functional as F
from util.util import resize_max_side
def safe_forward(model, im, seg, inter_s8=None, inter_s4=None):
    """
    Run ``model`` on inputs padded so that height and width are multiples of 8.

    Only the bottom/right edges are padded, and only by the minimal amount
    needed to reach the next multiple of 8 (a dimension that is already a
    multiple of 8 gets no padding). The image is padded with 0; ``seg`` and
    the optional ``inter_s8`` / ``inter_s4`` inputs are padded with -1. The
    predictions are cropped back to the original spatial size afterwards.

    Padding uses ``F.pad``, which keeps each tensor's device and dtype, so the
    padded path passes tensors of the same kind to ``model`` as the unpadded
    path does.

    Args:
        model: callable invoked as ``model(im, seg, inter_s8, inter_s4)``; it
            must return a mapping with the keys ``'pred_224'``,
            ``'pred_28_3'`` and ``'pred_56_2'``.
        im: image tensor; assumed (B, C, H, W) with the same H, W as ``seg``
            (TODO confirm with callers).
        seg: 4-D tensor whose last two dims (H, W) define the output size.
        inter_s8: optional tensor padded like ``seg``; ``None`` is forwarded.
        inter_s4: optional tensor padded like ``seg``; ``None`` is forwarded.

    Returns:
        dict mapping each of the three prediction keys to
        ``images[key][:, :, :H, :W]``.
    """
    _, _, ph, pw = seg.shape

    # Distance to the next multiple of 8 (0 when already aligned).
    pad_h = (-ph) % 8
    pad_w = (-pw) % 8

    if pad_h or pad_w:
        # F.pad order for the last two dims: (left, right, top, bottom).
        pad = (0, pad_w, 0, pad_h)
        im = F.pad(im, pad, value=0)
        seg = F.pad(seg, pad, value=-1)
        if inter_s8 is not None:
            inter_s8 = F.pad(inter_s8, pad, value=-1)
        if inter_s4 is not None:
            inter_s4 = F.pad(inter_s4, pad, value=-1)

    images = model(im, seg, inter_s8, inter_s4)
    # Crop away the padding so callers see the original resolution.
    return {
        key: images[key][:, :, :ph, :pw]
        for key in ('pred_224', 'pred_28_3', 'pred_56_2')
    }
def process_high_res_im(model, im, seg, para, name=None, aggre_device='cpu:0', coord=None, cell=None):
    """
    Run ``model`` once on the inputs after moving them to ``aggre_device``.

    Args:
        model: callable invoked positionally as ``model(im, seg, coord, cell)``.
        im: tensor moved to ``aggre_device`` before the call.
        seg: tensor moved to ``aggre_device`` before the call.
        para: options mapping; when ``para['clear']`` is truthy,
            ``torch.cuda.empty_cache()`` is called after the forward pass.
        name: unused; accepted for interface compatibility.
        aggre_device: device (or device string) the inputs are moved to.
        coord: forwarded unchanged to ``model``.
        cell: forwarded unchanged to ``model``.

    Returns:
        Whatever ``model`` returns.
    """
    im = im.to(aggre_device)
    seg = seg.to(aggre_device)
    images = model(im, seg, coord, cell)

    if para['clear']:
        # Release cached, currently-unused CUDA memory blocks.
        torch.cuda.empty_cache()

    return images
def process_im_single_pass(model, im, seg, min_size, para):
    """
    Global step only: a single forward pass at a bounded resolution.

    The longer side of the input is compared against ``min_size`` and
    ``para['L']``. Inputs that are too small are enlarged to ``min_size``
    (bicubic for the image, bilinear for ``seg``). Inputs that are too large
    are shrunk to ``para['L']`` with area interpolation. Both checks use the
    original size, so both resizes can apply.

    After ``safe_forward``, only ``images['pred_224']`` is resampled back to
    the original (h, w): 'area' after an upscale, bilinear after a downscale.
    The other returned entries stay at the processed resolution.
    """
    max_size = para['L']
    _, _, height, width = im.shape

    longest_side = max(height, width)
    was_upscaled = longest_side < min_size
    was_downscaled = longest_side > max_size

    if was_upscaled:
        im = resize_max_side(im, min_size, 'bicubic')
        seg = resize_max_side(seg, min_size, 'bilinear')
    if was_downscaled:
        im = resize_max_side(im, max_size, 'area')
        seg = resize_max_side(seg, max_size, 'area')

    images = safe_forward(model, im, seg)

    # Pick how to map the global prediction back to the caller's size.
    if was_upscaled:
        restore_kwargs = {'mode': 'area'}
    elif was_downscaled:
        restore_kwargs = {'mode': 'bilinear', 'align_corners': False}
    else:
        return images

    images['pred_224'] = F.interpolate(
        images['pred_224'], size=(height, width), **restore_kwargs
    )
    return images