File size: 656 Bytes
7d12390 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import numpy as np
import torch
def generate_label(inputs, imsize=512):
    """Generate integer label maps from model score outputs.

    Each element of ``inputs`` is reduced with an argmax over its first
    dimension (presumably the class/channel axis of ``(C, H, W)`` logits --
    confirm against callers). The resulting index maps are reshaped to
    ``(height, width)`` and stacked into one batch.

    Args:
        inputs: Iterable of score tensors, e.g. a ``(B, C, H, W)`` tensor or a
            list of ``(C, H, W)`` tensors. No gradient is tracked through the
            result.
        imsize: Spatial size of each label map. An int ``s`` means an
            ``s x s`` map (the original behavior). A ``(height, width)`` pair
            allows non-square maps. Each argmax result must contain exactly
            ``height * width`` elements.

    Returns:
        A CPU ``torch.long`` tensor of shape ``(num_inputs, height, width)``.
        An empty ``inputs`` yields an empty ``(0, height, width)`` tensor.

    Raises:
        RuntimeError: if an argmax result cannot be reshaped to
            ``(height, width)``.
    """
    if isinstance(imsize, (tuple, list)):
        height, width = imsize
    else:
        height = width = imsize

    labels = []
    for scores in inputs:
        # max(dim=0)[1] is the argmax over the first (class) axis.
        # detach() replaces the legacy `.data` access. Staying in torch
        # avoids the previous NumPy round-trip.
        index_map = scores.detach().max(dim=0)[1]
        labels.append(index_map.reshape(height, width).cpu())

    if not labels:
        # torch.stack/torch.cat reject an empty list, so return a well-formed
        # empty batch instead of crashing.
        return torch.empty((0, height, width), dtype=torch.long)

    return torch.stack(labels, dim=0).long()