| |
| ''' |
| @File : visualizer.py |
| @Time : 2022/04/05 11:39:33 |
| @Author : Shilong Liu |
| @Contact : liusl20@mail.tsinghua.edu.cn; slongliu86@gmail.com |
| Modified from COCO evaluator |
| ''' |
|
|
| import os, sys |
| from textwrap import wrap |
| import torch |
| import numpy as np |
| import cv2 |
| import datetime |
|
|
| import matplotlib.pyplot as plt |
| from matplotlib.collections import PatchCollection |
| from matplotlib.patches import Polygon |
| from pycocotools import mask as maskUtils |
| from matplotlib import transforms |
|
|
def renorm(img: torch.FloatTensor,
           mean=(0.485, 0.456, 0.406),
           std=(0.229, 0.224, 0.225)) -> torch.FloatTensor:
    """Undo per-channel normalization, i.e. compute ``img * std + mean`` per channel.

    Args:
        img: tensor of shape (3, H, W) or (N, 3, H, W).
        mean: per-channel means (a sequence or tensor of 3 values; a single value
            is broadcast to every channel). The defaults look like the common
            ImageNet statistics.
        std: per-channel standard deviations, same layout as ``mean``.

    Returns:
        A new tensor with the same shape as ``img``, on the same device as ``img``.

    Raises:
        AssertionError: if ``img`` is not 3-D/4-D, or its channel axis is not 3.
    """
    assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
    channel_dim = img.dim() - 3  # 0 for (C, H, W), 1 for (N, C, H, W)
    assert img.size(channel_dim) == 3, 'img.size(%d) should be 3 but "%d". (%s)' % (
        channel_dim, img.size(channel_dim), str(img.size()))
    # Build the statistics on img's device (CPU-only tensors crash for CUDA inputs),
    # using the default float dtype exactly as torch.Tensor(...) did.
    # Shape (C, 1, 1) broadcasts over H, W and over an optional leading batch axis.
    mean_t = torch.as_tensor(mean, dtype=torch.get_default_dtype(), device=img.device).reshape(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=torch.get_default_dtype(), device=img.device).reshape(-1, 1, 1)
    return img * std_t + mean_t
|
|
class ColorMap():
    """Build a constant-colour image whose extra last channel is an attention map.

    Every pixel gets the colour ``basergb``, and that pixel's attention value is
    appended after the colour channels. For a 3-value ``basergb`` the result is an
    RGBA-like array, presumably meant to be displayed with the attention as alpha.
    """

    def __init__(self, basergb=[255,255,0]):
        # np.array copies its argument, so the shared default list is never mutated.
        self.basergb = np.array(basergb)

    def __call__(self, attnmap):
        """Colour a 2-D attention map.

        Args:
            attnmap: np.ndarray of dtype uint8 and shape (H, W).

        Returns:
            np.ndarray of dtype uint8 and shape (H, W, len(basergb) + 1): the base
            colour at every pixel, followed by the attention value.
        """
        assert attnmap.dtype == np.uint8
        height, width = attnmap.shape
        # Read-only broadcast view of the colour over every pixel; concatenate copies it.
        colour = np.broadcast_to(self.basergb, (height, width) + self.basergb.shape)
        alpha = attnmap[..., np.newaxis]
        return np.concatenate((colour, alpha), axis=-1).astype(np.uint8)
|
|
|
|
class COCOVisualizer():
    """Draw COCO-style targets (boxes, optional labels and caption) over an image."""

    def __init__(self) -> None:
        pass

    def visualize(self, img, tgt, caption=None, dpi=120, savedir=None, show_in_console=True):
        """Render ``img`` with the annotations in ``tgt``; optionally save and/or show it.

        Args:
            img: tensor(3, H, W). It is passed through ``renorm`` (default stats)
                before display, so it is presumably normalized with those stats.
            tgt: dict; make sure its tensors are all on cpu.
                Must have items: 'image_id', 'boxes', 'size' (see ``addtgt``).
            caption: optional prefix for the saved file name.
            dpi: figure resolution passed to ``plt.figure``.
            savedir: if given, the figure is saved as
                ``{savedir}/[{caption}-]{image_id}-{timestamp}.png`` and then closed.
            show_in_console: if True, call ``plt.show()``.
        """
        plt.figure(dpi=dpi)
        # NOTE: this mutates the global matplotlib rcParams, affecting later figures too.
        plt.rcParams['font.size'] = '5'
        ax = plt.gca()
        img = renorm(img).permute(1, 2, 0)
        ax.imshow(img)

        self.addtgt(tgt)

        # Save BEFORE showing: plt.show() may close/clear the figure (interactive or
        # notebook-inline backends), after which savefig would write a blank image.
        if savedir is not None:
            timestamp = str(datetime.datetime.now()).replace(' ', '-')
            if caption is None:
                savename = '{}/{}-{}.png'.format(savedir, int(tgt['image_id']), timestamp)
            else:
                savename = '{}/{}-{}-{}.png'.format(savedir, caption, int(tgt['image_id']), timestamp)
            print("savename: {}".format(savename))
            os.makedirs(os.path.dirname(savename), exist_ok=True)
            plt.savefig(savename)

        if show_in_console:
            plt.show()

        if savedir is not None:
            plt.close()

    def addtgt(self, tgt):
        """Draw the boxes (and optional labels / caption) of ``tgt`` on the current axes.

        - tgt: dict. args:
            - boxes: num_boxes, 4. (cx, cy, w, h), normalized to [0,1] by the image size.
            - size: (H, W), used to convert the boxes to pixels.
            - box_label (optional): num_boxes entries, each drawn via str().
            - caption (optional): used as the axes title.
        """
        assert 'boxes' in tgt
        ax = plt.gca()
        H, W = tgt['size'].tolist()
        numbox = tgt['boxes'].shape[0]
        scale = torch.Tensor([W, H, W, H])  # loop-invariant pixel scale

        colors = []
        polygons = []
        corners = []  # pixel (x, y) of each box's top-left corner, for label placement
        for box in tgt['boxes'].cpu():
            # box * scale allocates a new tensor, so the in-place edit below
            # never touches tgt['boxes'].
            unnormbbox = box * scale
            # (cx, cy) -> top-left (x, y)
            unnormbbox[:2] -= unnormbbox[2:] / 2
            bbox_x, bbox_y, bbox_w, bbox_h = unnormbbox.tolist()
            corners.append((bbox_x, bbox_y))
            poly = [[bbox_x, bbox_y], [bbox_x, bbox_y + bbox_h],
                    [bbox_x + bbox_w, bbox_y + bbox_h], [bbox_x + bbox_w, bbox_y]]
            polygons.append(Polygon(np.array(poly).reshape((4, 2))))
            # random light colour: each RGB component in [0.4, 1.0)
            colors.append((np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0])

        # translucent fill, then an opaque outline in the same colour
        p = PatchCollection(polygons, facecolor=colors, linewidths=0, alpha=0.1)
        ax.add_collection(p)
        p = PatchCollection(polygons, facecolor='none', edgecolors=colors, linewidths=2)
        ax.add_collection(p)

        if 'box_label' in tgt:
            assert len(tgt['box_label']) == numbox, f"{len(tgt['box_label'])} != {numbox}, "
            for (bbox_x, bbox_y), bl, c in zip(corners, tgt['box_label'], colors):
                ax.text(bbox_x, bbox_y, str(bl), color='black',
                        bbox={'facecolor': c, 'alpha': 0.6, 'pad': 1})

        if 'caption' in tgt:
            ax.set_title(tgt['caption'], wrap=True)
|
|
|
|
|
|