| | |
| | ''' |
| | @File : visualizer.py |
| | @Time : 2022/04/05 11:39:33 |
| | @Author : Shilong Liu |
| | @Contact : liusl20@mail.tsinghua.edu.cn; slongliu86@gmail.com |
| | Modified from COCO evaluator |
| | ''' |
| |
|
| | import os, sys |
| | from textwrap import wrap |
| | import torch |
| | import numpy as np |
| | import cv2 |
| | import datetime |
| |
|
| | import matplotlib.pyplot as plt |
| | from matplotlib.collections import PatchCollection |
| | from matplotlib.patches import Polygon |
| | from pycocotools import mask as maskUtils |
| | from matplotlib import transforms |
| |
|
def renorm(img: torch.FloatTensor,
           mean=(0.485, 0.456, 0.406),
           std=(0.229, 0.224, 0.225)) -> torch.FloatTensor:
    """Undo per-channel normalization, returning ``img * std + mean``.

    Args:
        img: channel-first image tensor, either (3, H, W) or (N, 3, H, W).
        mean: per-channel means that were subtracted during normalization.
        std: per-channel standard deviations that were divided out.

    Returns:
        A new tensor with the same shape as ``img``.

    Raises:
        AssertionError: if ``img`` is not 3-D/4-D or its channel dim is not 3.
    """
    assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
    channel_dim = 0 if img.dim() == 3 else 1
    assert img.size(channel_dim) == 3, 'img.size(%d) should be 3 but "%d". (%s)' % (
        channel_dim, img.size(channel_dim), str(img.size()))
    # Build the statistics on img's device. The previous CPU-only tensors failed
    # for CUDA inputs. Use the default float dtype, as torch.Tensor(...) did, so
    # type promotion is unchanged. Shape (3, 1, 1) broadcasts over the channel
    # axis of both (3, H, W) and (N, 3, H, W).
    mean = torch.as_tensor(mean, dtype=torch.get_default_dtype(), device=img.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=torch.get_default_dtype(), device=img.device).view(-1, 1, 1)
    return img * std + mean
| |
|
class ColorMap():
    """Colourize a single-channel uint8 attention map as an RGBA-style image.

    Every pixel receives the same base colour (``basergb``). The attention
    value is appended as the last channel.
    """

    def __init__(self, basergb=[255,255,0]):
        # np.array copies the argument, so the shared default list is never mutated.
        self.basergb = np.array(basergb)

    def __call__(self, attnmap):
        """Return a uint8 array of shape (H, W, len(basergb) + 1).

        The leading channels hold ``basergb`` and the last channel holds ``attnmap``.

        Args:
            attnmap: 2-D ``np.uint8`` array of shape (H, W).

        Raises:
            AssertionError: if ``attnmap`` is not ``np.uint8``.
        """
        assert attnmap.dtype == np.uint8
        rows, cols = attnmap.shape
        # Replicate the base colour over every pixel without an explicit loop.
        colour_plane = np.broadcast_to(self.basergb, (rows, cols) + self.basergb.shape)
        alpha_plane = attnmap[..., None]
        # concatenate allocates a fresh array, so attnmap itself is never modified.
        stacked = np.concatenate((colour_plane, alpha_plane), axis=-1)
        return stacked.astype(np.uint8)
| |
|
| |
|
class COCOVisualizer():
    """Draw COCO-style targets (boxes, optional labels and caption) on an image with matplotlib."""

    def __init__(self) -> None:
        pass

    def visualize(self, img, tgt, caption=None, dpi=120, savedir=None, show_in_console=True):
        """Render ``img`` with the annotations in ``tgt``, then optionally save and/or show it.

        Args:
            img: normalized image tensor of shape (3, H, W). It is passed through
                ``renorm`` and permuted to (H, W, 3) before display.
            tgt: target dict; make sure its tensors are all on CPU. Must have the
                items 'image_id', 'boxes' and 'size' (see ``addtgt``).
            caption: optional string inserted into the saved file name.
            dpi: resolution of the matplotlib figure.
            savedir: directory to write a PNG into; nothing is saved when None.
            show_in_console: whether to call ``plt.show()``.
        """
        plt.figure(dpi=dpi)
        plt.rcParams['font.size'] = '5'
        ax = plt.gca()
        img = renorm(img).permute(1, 2, 0)
        ax.imshow(img)

        self.addtgt(tgt)

        # Save BEFORE showing. With interactive backends, plt.show() blocks until
        # the window is closed, and the figure is then discarded. A later savefig
        # would write a blank figure.
        if savedir is not None:
            # NOTE(review): str(datetime) contains ':' which is not a valid
            # filename character on Windows — confirm target platforms.
            timestamp = str(datetime.datetime.now()).replace(' ', '-')
            if caption is None:
                savename = '{}/{}-{}.png'.format(savedir, int(tgt['image_id']), timestamp)
            else:
                savename = '{}/{}-{}-{}.png'.format(savedir, caption, int(tgt['image_id']), timestamp)
            print("savename: {}".format(savename))
            os.makedirs(os.path.dirname(savename), exist_ok=True)
            plt.savefig(savename)

        if show_in_console:
            plt.show()
        # Always release the figure so repeated calls do not accumulate open figures.
        plt.close()

    def addtgt(self, tgt):
        """Draw the boxes (and optional labels / caption) of ``tgt`` on the current axes.

        - tgt: dict. args:
            - boxes: (num_boxes, 4) tensor of normalized (cx, cy, w, h) in [0, 1].
            - size: two-element tensor (H, W) used to scale boxes to pixels.
            - box_label: optional, num_boxes entries drawn as text at each box's top-left corner.
            - caption: optional, used as the axes title.
        """
        assert 'boxes' in tgt
        ax = plt.gca()
        H, W = tgt['size'].tolist()
        numbox = tgt['boxes'].shape[0]

        # Factors mapping normalized coordinates to pixels; loop-invariant.
        scale = torch.Tensor([W, H, W, H])
        color = []
        polygons = []
        boxes = []
        for box in tgt['boxes'].cpu():
            unnormbbox = box * scale
            # Shift the centre point to the top-left corner.
            unnormbbox[:2] -= unnormbbox[2:] / 2
            bbox_x, bbox_y, bbox_w, bbox_h = unnormbbox.tolist()
            boxes.append([bbox_x, bbox_y, bbox_w, bbox_h])
            poly = [[bbox_x, bbox_y], [bbox_x, bbox_y + bbox_h],
                    [bbox_x + bbox_w, bbox_y + bbox_h], [bbox_x + bbox_w, bbox_y]]
            polygons.append(Polygon(np.array(poly).reshape((4, 2))))
            # Random light colour per box: each RGB channel in [0.4, 1.0).
            color.append((np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0])

        # Translucent fill, then an opaque outline in the same colour.
        ax.add_collection(PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1))
        ax.add_collection(PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2))

        if 'box_label' in tgt:
            assert len(tgt['box_label']) == numbox, f"{len(tgt['box_label'])} != {numbox}, "
            for (bbox_x, bbox_y, _, _), bl, c in zip(boxes, tgt['box_label'], color):
                ax.text(bbox_x, bbox_y, str(bl), color='black',
                        bbox={'facecolor': c, 'alpha': 0.6, 'pad': 1})

        if 'caption' in tgt:
            ax.set_title(tgt['caption'], wrap=True)
| |
|
| |
|
| |
|