test / utils.py
danicor's picture
Update utils.py
bc5c2cc verified
raw
history blame contribute delete
656 Bytes
import numpy as np
import torch
def generate_label(inputs, imsize=512):
    """Generate integer label maps from per-pixel class scores.

    Args:
        inputs: Either a tensor whose dimension 0 is the batch, or an iterable
            of per-sample tensors. Within each sample, dimension 0 is reduced
            with ``argmax`` to pick the winning class. Presumably these are
            model logits shaped ``(C, H, W)`` per sample -- TODO confirm with
            callers.
        imsize: Side length of the square output maps. Each sample's
            prediction must contain exactly ``imsize * imsize`` elements.

    Returns:
        A CPU ``torch.LongTensor`` of shape ``(batch, imsize, imsize)`` holding
        the index of the maximal score for each pixel (first index on ties).
        An empty batch yields a tensor of shape ``(0, imsize, imsize)``.

    Raises:
        RuntimeError: if a sample's prediction cannot be reshaped to
            ``(imsize, imsize)``.
    """
    if not torch.is_tensor(inputs):
        # Iterable of per-sample tensors: reduce each sample's class axis
        # (dim 0), moving to CPU per sample so mixed devices can be stacked.
        samples = [sample.detach().argmax(dim=0).cpu() for sample in inputs]
        if not samples:
            return torch.empty((0, imsize, imsize), dtype=torch.long)
        preds = torch.stack(samples)
    elif inputs.shape[0] == 0:
        return torch.empty((0, imsize, imsize), dtype=torch.long)
    else:
        # Batched tensor: dim 1 is each sample's class axis.
        preds = inputs.detach().argmax(dim=1)

    # Row-major reinterpretation of each sample's prediction as an
    # (imsize, imsize) map; raises RuntimeError if the element count differs.
    return preds.cpu().reshape(preds.shape[0], imsize, imsize).long()