File size: 656 Bytes
7d12390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import numpy as np
import torch

def generate_label(inputs, imsize=512):
    """Generate label maps from model outputs.

    Each element of ``inputs`` is reduced with an argmax over its first
    dimension (presumably the class/channel axis — confirm against the model
    producing the outputs). The resulting index map is reshaped to
    ``(imsize, imsize)``.

    Args:
        inputs: An iterable of tensors, e.g. a batched tensor iterated along
            dim 0 or a list of per-sample tensors. After the argmax, each
            element must hold exactly ``imsize * imsize`` values.
        imsize: Side length of the square label map.

    Returns:
        A CPU ``torch.LongTensor`` of shape ``(N, imsize, imsize)``, where
        ``N`` is the number of elements in ``inputs``. An empty ``inputs``
        yields an empty tensor of shape ``(0, imsize, imsize)``.

    Raises:
        RuntimeError: If an element's argmax map does not hold exactly
            ``imsize * imsize`` values (raised by ``reshape``).
    """
    labels = []
    for logits in inputs:
        # ``max`` returns (values, indices); keep the indices. Detach so the
        # result never carries autograd history, then bring it to the CPU.
        pred = logits.detach().max(dim=0)[1].cpu()
        # NOTE(review): like the original ``view``, this reshapes any map with
        # imsize**2 values, so a non-square map of that size is rearranged
        # rather than rejected — kept for backward compatibility.
        labels.append(pred.reshape(imsize, imsize))

    if not labels:
        # torch.stack/torch.cat reject an empty list; return an empty batch.
        return torch.empty((0, imsize, imsize), dtype=torch.long)

    return torch.stack(labels).long()