| import os |
| import sys |
| import numpy as np |
| import torch |
| import matplotlib.pyplot as plt |
| from cityscapesscripts.helpers.labels import labels as cs_labels |
| from datasets.cityscapes import get_cs_labeldata |
| from datasets.cocostuff import get_coco_labeldata |
| from datasets.potsdam import get_pd_labeldata |
|
|
| sys.path.append(os.getcwd()) |
| import modules.transforms as transforms |
|
|
|
|
def visualize_segmentation(img=None,
                           label=None,
                           linear=None,
                           mlp=None,
                           cluster=None,
                           dataset_name=None,
                           additional=None,
                           additional_name=None,
                           additional2=None,
                           additional_name2=None,
                           legend=None,
                           name=None):
    """Render an image, its ground truth and optional predictions side by side.

    Panels are laid out in one row in this order: Image, Ground Truth, Linear,
    MLP, Cluster, ``additional``, ``additional2``. Optional panels are only
    drawn (and only take a slot) when the corresponding argument is given.

    Args:
        img: Image tensor, (1, C, H, W) or (C, H, W); min-max normalised to
            [0, 1] for display.
        label: Ground-truth label tensor of shape (1, 1, H, W). Pixels equal to
            255 are drawn with palette entry 27. The caller's tensor is not
            modified.
        linear, mlp, cluster, additional: Optional integer label maps of shape
            (1, H, W). Each is cast to uint8 and colored with the dataset palette.
        dataset_name: "cityscapes" (built-in palette) or "cocostuff" (palette
            from ``get_coco_labeldata()``).
        additional_name: Title of the ``additional`` panel.
        additional2: Optional tensor shown as-is with ``imshow`` (not colored).
        additional_name2: Title of the ``additional2`` panel.
        legend: Unused by this function; accepted for backward compatibility.
        name: If given, the figure is also saved to this path.

    Returns:
        np.ndarray: uint8 RGB rendering of the figure, shape (height, width, 3).

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    if dataset_name == "cityscapes":
        colormap = np.array([
            [128, 64, 128],
            [244, 35, 232],
            [250, 170, 160],
            [230, 150, 140],
            [70, 70, 70],
            [102, 102, 156],
            [190, 153, 153],
            [180, 165, 180],
            [150, 100, 100],
            [150, 120, 90],
            [153, 153, 153],
            [153, 153, 153],
            [250, 170, 30],
            [220, 220, 0],
            [107, 142, 35],
            [152, 251, 152],
            [70, 130, 180],
            [220, 20, 60],
            [255, 0, 0],
            [0, 0, 142],
            [0, 0, 70],
            [0, 60, 100],
            [0, 0, 90],
            [0, 0, 110],
            [0, 80, 100],
            [0, 0, 230],
            [119, 11, 32],
            [0, 0, 0],
            [220, 220, 220]])
    elif dataset_name == "cocostuff":
        colormap = get_coco_labeldata()[-1]
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    def _colorize(pred):
        """Map a (1, H, W) integer prediction tensor to an (H, W, 3) color image."""
        pred = pred.cpu().numpy().transpose(1, 2, 0).astype('uint8')
        return colormap[pred.flatten()].reshape(pred.shape[0], pred.shape[1], 3)

    orig_h, orig_w = label.cpu().shape[-2:]
    img = img.cpu().squeeze(0).numpy().transpose(1, 2, 0)
    img = (img - img.min()) / (img - img.min()).max()
    # .copy(): .numpy() of a CPU tensor shares its memory, so without it the
    # in-place remap below would overwrite the caller's label tensor.
    label = label.cpu().squeeze(0).numpy().transpose(1, 2, 0).copy()
    label[label == 255] = 27
    colored_label = colormap[label.flatten()].reshape(orig_h, orig_w, 3)

    # (title, picture) for every panel, in display order. The subplot count is
    # derived from this list so no empty slot is reserved for a missing input.
    panels = [('Image', img), ('Ground Truth', colored_label)]
    if linear is not None:
        panels.append(('Linear', _colorize(linear)))
    if mlp is not None:
        panels.append(('MLP', _colorize(mlp)))
    if cluster is not None:
        panels.append(('Cluster', _colorize(cluster)))
    if additional is not None:
        panels.append((additional_name, _colorize(additional)))
    if additional2 is not None:
        panels.append((additional_name2, additional2.cpu().numpy()))

    fig = plt.figure(figsize=(8, 2), dpi=200)
    for position, (title, picture) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.gca().set_title(title)
        plt.imshow(picture)
        plt.axis("off")

    if name is not None:
        plt.savefig(name)
    fig.canvas.draw()
    # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8,
    # removed in 3.10); dropping alpha keeps the (H, W, 3) RGB output.
    data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close('all')
    return data
|
|
|
|
|
|
def visualize_confusion_matrix(cls_names, meter, name=None):
    """Plot a column-normalised confusion matrix and return it as an RGB image.

    Each column of ``meter.histogram`` is divided by its sum. Every cell is
    annotated with its value as a percentage, rounded to one decimal place.

    Args:
        cls_names: Class names used as tick labels on both axes.
        meter: Object with a ``histogram`` attribute, a 2-D torch tensor of
            counts. A column that sums to 0 is shown as all zeros instead of NaN.
        name: If given, the figure is also saved to this path.

    Returns:
        np.ndarray: uint8 RGB rendering of the figure, shape (height, width, 3).
    """
    histogram = meter.histogram
    col_sums = histogram.sum(dim=0)
    # Avoid 0 / 0 = NaN for columns with no samples: dividing by 1 keeps them 0.
    col_sums = torch.where(col_sums == 0, torch.ones_like(col_sums), col_sums)
    conf_matrix = np.array((histogram / col_sums).cpu(), dtype=np.float16)
    num_classes = len(cls_names)

    fig, ax = plt.subplots(figsize=(15, 15))
    # The diagonal is zeroed only in the shaded image (the printed numbers keep
    # it), presumably so off-diagonal confusions drive the color scale.
    ax.matshow(torch.Tensor(conf_matrix).fill_diagonal_(0), cmap=plt.cm.Blues, alpha=0.8)
    for i in range(conf_matrix.shape[0]):
        for j in range(conf_matrix.shape[1]):
            ax.text(x=j, y=i, s=(conf_matrix[i, j] * 100).round(1),
                    va='center', ha='center', size='large')
    ax.set_xticks(list(range(num_classes)))
    ax.set_xticklabels(cls_names, rotation=90, ha='center', fontsize=12)
    ax.set_yticks(list(range(num_classes)))
    ax.set_yticklabels(cls_names, fontsize=12)
    ax.set_xlabel('Predictions', fontsize=18)
    ax.set_ylabel('Actuals', fontsize=18)
    ax.set_title('Confusion Matrix', fontsize=18)

    if name is not None:
        plt.savefig(name)
    fig.canvas.draw()
    # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8,
    # removed in 3.10); dropping alpha keeps the (H, W, 3) RGB output.
    data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close('all')
    return data
|
|
|
|
|
|
|
|
|
|
def batch_visualize_segmentation(img=None,
                                 label=None,
                                 in1=None,
                                 in2=None,
                                 in3=None,
                                 in4=None,
                                 dataset_name=None):
    """Render one figure row per batch element and stack the rows vertically.

    Each row shows the image, the ground truth and every supplied prediction.
    Titles are drawn on the first row only.

    Args:
        img: Batch of images; each element ``img[b]`` is (C, H, W) and is
            min-max normalised for display.
        label: Batch of label maps; each element ``label[b]`` is (1, H, W).
            Values above 27 are clamped to 27 for coloring. The caller's
            tensor is not modified.
        in1, in2, in3, in4: Optional ``(title, predictions)`` pairs, where
            ``predictions[b]`` is an (H, W) integer map (cast to uint8).
        dataset_name: "cityscapes", "cocostuff" or "potsdam"; selects the palette.

    Returns:
        np.ndarray: uint8 RGB image made by ``np.vstack`` of the per-element
        renderings.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    if dataset_name == "cityscapes":
        colormap = get_cs_labeldata()[-1]
    elif dataset_name == "cocostuff":
        colormap = get_coco_labeldata()[-1]
    elif dataset_name == "potsdam":
        colormap = get_pd_labeldata()[-1]
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    # Only the prediction sources that were actually supplied, in order.
    sources = [source for source in (in1, in2, in3, in4) if source is not None]
    num_subplots = len(sources) + 2

    def _vis_one_img(idx, img, label, ins):
        """Render image, GT and each ``(title, (1, H, W) map)`` in ``ins`` as one row."""
        orig_h, orig_w = label.cpu().shape[-2:]
        img = img.cpu().numpy().transpose(1, 2, 0)
        img = (img - img.min()) / (img - img.min()).max()
        # .copy(): .numpy() of a CPU tensor shares its memory, so the clamp
        # below would otherwise write into the caller's label batch.
        label = label.cpu().numpy().transpose(1, 2, 0).copy()
        label[label > 27] = 27
        colored_label = colormap[label.flatten()].reshape(orig_h, orig_w, 3)

        fig = plt.figure(figsize=(10, 2), dpi=150)
        plt.subplot(1, num_subplots, 1)
        if idx == 0:
            plt.gca().set_title('Image')
        plt.imshow(img)
        plt.axis("off")
        plt.subplot(1, num_subplots, 2)
        if idx == 0:
            plt.gca().set_title('Ground Truth')
        plt.imshow(colored_label)
        plt.axis("off")
        for position, (title, pred) in enumerate(ins, start=3):
            vis = pred.cpu().numpy().transpose(1, 2, 0).astype('uint8')
            vis = colormap[vis.flatten()].reshape(vis.shape[0], vis.shape[1], 3)
            plt.subplot(1, num_subplots, position)
            if idx == 0:
                plt.gca().set_title(title)
            plt.imshow(vis)
            plt.axis("off")

        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8,
        # removed in 3.10); dropping alpha keeps the (H, W, 3) RGB output.
        one_vis = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
        plt.close('all')
        return one_vis

    rows = []
    for idx, (one_img, one_label) in enumerate(zip(img, label)):
        # Slice this element out of every prediction batch, keeping a leading dim.
        ins = [[source[0], source[1][idx].unsqueeze(0)] for source in sources]
        rows.append(_vis_one_img(idx, one_img, one_label, ins))

    return np.vstack(rows)
|
|
|
|
|
|
def visualize_single_masks(img,
                           label,
                           data,
                           dataset_name=None):
    """Plot the intermediate mask stages, one row per mask, as an RGB image.

    The columns of each row are: Image, GT, 1.Eig, 1.EigNN, +Thresh, +CRF,
    PAMR and Mask. Titles are drawn on the first row only.

    Args:
        img: Image tensor; ``img.squeeze(0).permute(1, 2, 0)`` is what is shown.
            It is min-max normalised first (a constant image becomes all zeros).
        label: Label tensor; ``label.squeeze(0).squeeze(0)`` is used as an index
            map into the dataset palette.
        data: Mapping with sequences under 'sim', 'nnsim', 'nnsim_tresh', 'crf',
            'pamr' and 'outmask'. There is one row per entry of ``data['sim']``,
            and ``zip`` stops at the shortest sequence. The tensors in
            ``data['outmask']`` are not modified.
        dataset_name: "cityscapes", "cocostuff" or "potsdam"; selects the palette.

    Returns:
        np.ndarray: uint8 RGB rendering of the figure, shape (height, width, 3).

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    if dataset_name == "cityscapes":
        colormap = get_cs_labeldata()[-1]
    elif dataset_name == "cocostuff":
        colormap = get_coco_labeldata()[-1]
    elif dataset_name == "potsdam":
        colormap = get_pd_labeldata()[-1]
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    rows = len(data['sim'])
    cols = 8
    fig = plt.figure(figsize=(rows * 2, 7 * 2), dpi=150)

    # The image and GT panels are the same in every row, so prepare them once.
    img_min = img.min()
    img_range = img.max() - img_min
    # Guard the min-max normalisation against a constant image (0 / 0 -> NaN).
    img = (img - img_min) / img_range if img_range > 0 else img - img_min
    shown_img = img.squeeze(0).permute(1, 2, 0).cpu()
    plotlabel = colormap[label.squeeze(0).squeeze(0).int().cpu()]

    stages = zip(data['sim'], data['nnsim'], data['nnsim_tresh'],
                 data['crf'], data['pamr'], data['outmask'])
    for indx, (sim, nnsim, nnsim_thresh, crf, pamr, mask) in enumerate(stages):
        # Work on a copy so the caller's mask tensor is left untouched.
        mask_np = mask.cpu().numpy().copy()
        # Presumably forces 0 into the auto-scaled color range so an all-ones
        # mask is still drawn as foreground.
        mask_np[0, 0] = 0
        panels = (
            ('Image', shown_img, None),
            ('GT', plotlabel, None),
            ('1.Eig', sim.cpu().numpy(), None),
            ('1.EigNN', nnsim.cpu().numpy(), None),
            ('+Thresh', nnsim_thresh, None),
            ('+CRF', crf, None),
            ('PAMR', pamr.squeeze().cpu().numpy(), None),
            ('Mask', mask_np, 'Greys'),
        )
        for col, (title, picture, cmap) in enumerate(panels, start=1):
            plt.subplot(rows, cols, col + indx * cols)
            if indx == 0:
                plt.title(title)
            plt.imshow(picture, cmap=cmap)
            plt.axis('off')

    fig.canvas.draw()
    # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8,
    # removed in 3.10); dropping alpha keeps the (H, W, 3) RGB output.
    one_vis = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close('all')
    return one_vis
|
|
|
|
|
|
|
|
|
|
def visualize_pseudo_paper(img,
                           label,
                           pseudo_gt,
                           pseudo_plain,
                           dataset_name=None,
                           save_name=None):
    """Save paper figures of an image, its GT, a pseudo label and raw clusters.

    Writes ``save_name + '.pdf'`` with the four panels side by side. It also
    writes one PNG per panel to
    ``<dirname(save_name)>/singleimgs/<basename(save_name)>_{img,gt,pseudo,pseudoc}.png``.

    Args:
        img: Image tensor; min-max normalised, and
            ``img.squeeze(0).permute(1, 2, 0)`` is shown.
        label: GT tensor; ``label.squeeze(0).squeeze(0)`` indexes the dataset palette.
        pseudo_gt: Pseudo-label tensor, colored the same way as ``label``.
        pseudo_plain: Raw cluster-id tensor. Ids index a fixed 400-color random
            palette (seed 0), and the value 255 is drawn black. The caller's
            tensor is not modified.
        dataset_name: "cityscapes", "cocostuff" or "potsdam"; selects the palette.
        save_name: Output path prefix. Required, despite the ``None`` default.

    Raises:
        ValueError: If ``dataset_name`` is not supported or ``save_name`` is None.
    """
    if dataset_name == "cityscapes":
        colormap = get_cs_labeldata()[-1]
    elif dataset_name == "cocostuff":
        colormap = get_coco_labeldata()[-1]
    elif dataset_name == "potsdam":
        colormap = get_pd_labeldata()[-1]
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")
    if save_name is None:
        raise ValueError("save_name is required")

    # 400 random colors plus black at index 400 for ignored (255) pixels. A
    # local RandomState(0) yields the same colors as np.random.seed(0) without
    # resetting the global NumPy RNG for the rest of the program.
    rng = np.random.RandomState(0)
    cb_colormap = np.array([list(rng.randint(0, 255, size=(1, 3))[0]) for _ in range(400)] + [[0, 0, 0]])
    # clone(): .int().cpu() returns the same tensor when it is already int32 on
    # the CPU, so remapping in place would modify the caller's pseudo_plain.
    cluster_ids = pseudo_plain.int().cpu().clone()
    cluster_ids[cluster_ids == 255] = 400
    pseudo_plain = cb_colormap[cluster_ids.numpy()].squeeze()

    img = (img - img.min()) / (img.max() - img.min())
    img = img.squeeze(0).permute(1, 2, 0).cpu()
    plotlabel = colormap[label.squeeze(0).squeeze(0).int().cpu()]
    plotpseudo = colormap[pseudo_gt.squeeze(0).squeeze(0).int().cpu()]
    panels = [(img, 'img'), (plotlabel, 'gt'), (plotpseudo, 'pseudo'), (pseudo_plain, 'pseudoc')]

    fig = plt.figure(figsize=(8, 2), dpi=150)
    fig.subplots_adjust(left=0.1,
                        bottom=0.1,
                        right=0.5,
                        top=0.5,
                        wspace=0.05,
                        hspace=0.0)
    for position, (picture, _) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        plt.imshow(picture)
        plt.axis('off')
    plt.savefig(save_name + '.pdf', bbox_inches='tight', pad_inches=0.0)
    # Close every figure once it has been saved; previously all five were leaked.
    plt.close(fig)

    single_dir = os.path.join(os.path.dirname(save_name), 'singleimgs')
    os.makedirs(single_dir, exist_ok=True)
    base_name = os.path.basename(save_name)
    for picture, suffix in panels:
        single_fig = plt.figure(figsize=(2, 2), dpi=300)
        plt.imshow(picture)
        plt.axis('off')
        plt.savefig(os.path.join(single_dir, base_name + '_' + suffix + '.png'),
                    bbox_inches='tight', pad_inches=0.0)
        plt.close(single_fig)
| |
| |
| |
| |
| |
def logits_to_image(logits=None,
                    img=None,
                    label=None,
                    dataset_name=None,
                    save_path=None,
                    save_imggt=False):
    """Save a colored prediction map, and optionally the image and GT, as PNGs.

    Always writes ``save_path + '_pred.png'``. When ``save_imggt`` is true it
    also writes ``save_path + '_img.png'`` and ``save_path + '_gt.png'``.

    Args:
        logits: (1, H, W) tensor whose values are used directly as palette
            indices (after a cast to uint8). Despite the name, these are class
            ids rather than scores.
        img: (C, H, W) image tensor, min-max normalised. Only used when
            ``save_imggt`` is true.
        label: (1, H, W) label tensor. Values above 27 are clamped to 27 for
            coloring, and the caller's tensor is not modified. Only used when
            ``save_imggt`` is true.
        dataset_name: "cityscapes", "cocostuff" or "potsdam"; selects the palette.
        save_path: Output path prefix.
        save_imggt: Whether to also save the image and the ground truth.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    if dataset_name == "cityscapes":
        colormap = get_cs_labeldata()[-1]
    elif dataset_name == "cocostuff":
        colormap = get_coco_labeldata()[-1]
    elif dataset_name == "potsdam":
        colormap = get_pd_labeldata()[-1]
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    def _save_single(picture, suffix):
        """Save ``picture`` alone, without axes or padding, to ``save_path + suffix``."""
        plt.figure(figsize=(2, 2), dpi=400)
        plt.subplot(1, 1, 1)
        plt.imshow(picture)
        plt.axis("off")
        plt.savefig(save_path + suffix, bbox_inches='tight', pad_inches=0.0)
        plt.close('all')

    vis = logits.cpu().numpy().transpose(1, 2, 0).astype('uint8')
    vis = colormap[vis.flatten()].reshape(vis.shape[0], vis.shape[1], 3)
    _save_single(vis, '_pred.png')

    if save_imggt:
        orig_h, orig_w = label.cpu().shape[-2:]
        img = img.cpu().numpy().transpose(1, 2, 0)
        img = (img - img.min()) / (img - img.min()).max()
        # .copy(): .numpy() of a CPU tensor shares its memory, so the clamp
        # below would otherwise write into the caller's label tensor.
        label = label.cpu().numpy().transpose(1, 2, 0).copy()
        label[label > 27] = 27
        colored_label = colormap[label.flatten()].reshape(orig_h, orig_w, 3)

        _save_single(img, '_img.png')
        _save_single(colored_label, '_gt.png')
| |
| |
| |
| |
class Vis_Demo:
    """Colors predicted class-id maps with the COCO-Stuff palette for demos."""

    def __init__(self):
        super().__init__()
        # The palette is the last element returned by get_coco_labeldata().
        self.colormap = get_coco_labeldata()[-1]

    def apply_colors(self, logits):
        """Return an (H, W, 3) color image for a (1, H, W) tensor of class ids.

        The ids are cast to uint8 (wrapping modulo 256) before indexing
        ``self.colormap``.
        """
        ids = logits.cpu().numpy().transpose(1, 2, 0).astype('uint8')
        height, width = ids.shape[0], ids.shape[1]
        colors = self.colormap[ids.reshape(-1)]
        return colors.reshape(height, width, 3)
| |
| |
| |
def visualize_demo(img, pseudo, alpha=0.5):
    """Blend an image with a random-color rendering of a cluster/label id map.

    Args:
        img: Normalised 3-D image tensor. It is un-normalised with
            ``transforms.UnNormalize()``, scaled by 255 and permuted
            ``(1, 2, 0)`` to channels-last.
        pseudo: Integer id tensor. Ids index a fixed palette of 400 random
            colors (seed 0), and the value 255 is drawn black. The caller's
            tensor is not modified.
        alpha: Weight of the image in the blend; the colors get ``1 - alpha``.

    Returns:
        np.ndarray: The blended image cast to uint8.
    """
    # A local RandomState(0) gives the same colors as np.random.seed(0)
    # without resetting the global NumPy RNG for the caller.
    rng = np.random.RandomState(0)
    cb_colomap = np.array([list(rng.randint(0, 255, size=(1, 3))[0]) for _ in range(400)] + [[0, 0, 0]])
    # .copy(): for an int64 CPU tensor, .long().cpu().numpy() shares its memory,
    # so remapping in place would otherwise modify the caller's ``pseudo``.
    pseudo_plain = pseudo.long().cpu().numpy().copy()
    pseudo_plain[pseudo_plain == 255] = 400
    pseudo_plain = cb_colomap[pseudo_plain].squeeze()

    img = transforms.UnNormalize()(img) * 255
    img = img.permute(1, 2, 0).long().cpu().numpy()
    out = alpha * img + (1 - alpha) * pseudo_plain

    return np.array(out, dtype=np.uint8)
| |
|
|