File size: 2,254 Bytes
352cafd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torch.nn.functional as F

from util.util import resize_max_side


def safe_forward(model, im, seg, inter_s8=None, inter_s4=None):
    """
    Run ``model`` on inputs padded so that height and width are multiples of 8.

    Padding is added on the bottom/right only: ``im`` is padded with 0, while
    ``seg``, ``inter_s8`` and ``inter_s4`` are padded with -1. Each spatial side
    is rounded up to the next multiple of 8; a side that is already aligned is
    left unchanged. The model's predictions are cropped back to the original
    ``(H, W)`` taken from ``seg``.

    Args:
        model: Callable invoked as ``model(im, seg, inter_s8, inter_s4)``. It
            must return a mapping containing ``'pred_224'``, ``'pred_28_3'``
            and ``'pred_56_2'``.
        im: 4-D image tensor ``(B, C, H, W)`` with the same spatial size as
            ``seg``.
        seg: 4-D tensor ``(B, C_s, H, W)``; its last two dims define H and W.
        inter_s8, inter_s4: Optional 4-D tensors with the same spatial size as
            ``seg``; padded like ``seg`` when given, passed through as ``None``
            otherwise.

    Returns:
        dict mapping each of the three prediction keys to the corresponding
        model output cropped to ``[:, :, :H, :W]``.
    """
    _, _, ph, pw = seg.shape
    # Round each side up to the next multiple of 8 (no-op if already aligned).
    new_h = -(-ph // 8) * 8
    new_w = -(-pw // 8) * 8

    if (new_h, new_w) != (ph, pw):
        # F.pad's last-two-dims order is (left, right, top, bottom). Padding
        # keeps every tensor on its own device/dtype instead of forcing CUDA.
        pad = (0, new_w - pw, 0, new_h - ph)
        im = F.pad(im, pad, value=0)
        seg = F.pad(seg, pad, value=-1)
        if inter_s8 is not None:
            inter_s8 = F.pad(inter_s8, pad, value=-1)
        if inter_s4 is not None:
            inter_s4 = F.pad(inter_s4, pad, value=-1)

    images = model(im, seg, inter_s8, inter_s4)

    # Strip the bottom/right padding. Assumes each prediction shares the padded
    # input's spatial size -- TODO confirm for every key.
    return_im = {
        key: images[key][:, :, 0:ph, 0:pw]
        for key in ('pred_224', 'pred_28_3', 'pred_56_2')
    }
    # Drop our reference to the full model output (incl. any other keys) early.
    del images

    return return_im

def process_high_res_im(model, im, seg, para, name=None, aggre_device='cpu:0', coord=None, cell=None):
    """
    Move ``im`` and ``seg`` to ``aggre_device`` and run ``model`` on them once.

    Args:
        model: Callable invoked as ``model(im, seg, coord, cell)``.
        im: Input tensor, moved to ``aggre_device`` before the call.
        seg: Input tensor, moved to ``aggre_device`` before the call.
        para: Options mapping; must contain ``'clear'``. When it is truthy,
            ``torch.cuda.empty_cache()`` is called after the forward pass.
        name: Unused; kept for call-site compatibility.
        aggre_device: Device string/object passed to ``Tensor.to``.
        coord: Forwarded unchanged to ``model``.
        cell: Forwarded unchanged to ``model``.

    Returns:
        Whatever ``model`` returns.
    """
    im = im.to(aggre_device)
    seg = seg.to(aggre_device)

    images = model(im, seg, coord, cell)

    if para['clear']:
        torch.cuda.empty_cache()

    return images


def process_im_single_pass(model, im, seg, min_size, para):
    """
    Global-step-only inference: a single forward pass at a bounded resolution.

    If the longer side of ``im`` is below ``min_size``, both inputs are upscaled
    so that side equals ``min_size`` (``im`` bicubic, ``seg`` bilinear). If it is
    above ``para['L']``, both are downscaled to ``para['L']`` with area
    interpolation. After ``safe_forward``, only ``'pred_224'`` is resized back
    to the original ``(h, w)``; the other returned entries stay at the
    processing resolution.

    Args:
        model: Network forwarded to ``safe_forward``.
        im: 4-D image tensor; its last two dims give the original ``(h, w)``.
        seg: Tensor resized alongside ``im``.
        min_size: Lower bound for the longer side.
        para: Options mapping; ``para['L']`` is the upper bound for the longer
            side.

    Returns:
        The dict produced by ``safe_forward``, with ``'pred_224'`` restored to
        ``(h, w)`` whenever a resize was applied.
    """
    max_size = para['L']

    _, _, h, w = im.shape
    longest = max(h, w)
    upscaled = longest < min_size
    downscaled = longest > max_size

    if upscaled:
        im = resize_max_side(im, min_size, 'bicubic')
        seg = resize_max_side(seg, min_size, 'bilinear')

    if downscaled:
        im = resize_max_side(im, max_size, 'area')
        seg = resize_max_side(seg, max_size, 'area')

    images = safe_forward(model, im, seg)

    # Bring the main prediction back to the caller's resolution: area when
    # shrinking an upscaled result, bilinear when enlarging a downscaled one.
    if upscaled:
        restored = F.interpolate(images['pred_224'], size=(h, w), mode='area')
        images['pred_224'] = restored
    elif downscaled:
        restored = F.interpolate(images['pred_224'], size=(h, w), mode='bilinear', align_corners=False)
        images['pred_224'] = restored

    return images