import numpy as np
import torch


def generate_label(inputs, imsize=512):
    """Generate label maps from model outputs.

    For every sample, take the index of the maximal value along the
    sample's first dimension. This is presumably the class/channel axis of
    per-pixel scores (TODO confirm with callers). Then reshape each
    resulting map to ``(imsize, imsize)``.

    Args:
        inputs: Either a tensor whose dim 0 indexes samples (e.g. ``(N, C, H, W)``),
            or an iterable of equally-shaped per-sample tensors (e.g. ``(C, H, W)``).
        imsize: Side length of the square label map. Each sample's reduced
            prediction must contain exactly ``imsize * imsize`` elements.

    Returns:
        A CPU ``torch.int64`` tensor of shape ``(N, imsize, imsize)``. An empty
        ``inputs`` yields a tensor of shape ``(0, imsize, imsize)``.

    Raises:
        RuntimeError: If the predictions cannot be reshaped to
            ``(N, imsize, imsize)``.
    """
    if not isinstance(inputs, torch.Tensor):
        samples = list(inputs)
        if not samples:
            # torch.stack / torch.cat reject empty sequences, so an empty
            # batch gets an empty label batch instead of an exception.
            return torch.empty((0, imsize, imsize), dtype=torch.long)
        inputs = torch.stack(samples)

    # One batched reduction replaces the per-sample loop. It also replaces
    # the NumPy round-trip. torch.max(dim=...) returns the first maximal
    # index on ties, as the original per-sample call did.
    pred = inputs.detach().max(dim=1)[1]
    # Indices from torch.max are already int64. .long() keeps the result's
    # dtype explicit. Passing an explicit sample count (not -1) means a
    # per-sample size mismatch still raises.
    return pred.cpu().reshape(pred.shape[0], imsize, imsize).long()