| |
| """ |
| @File : visualizer.py |
| @Time : 2022/04/05 11:39:33 |
| @Author : Shilong Liu |
| @Contact : slongliu86@gmail.com |
| """ |
|
|
| import datetime |
| import os |
|
|
| import cv2 |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| from matplotlib import transforms |
| from matplotlib.collections import PatchCollection |
| from matplotlib.patches import Polygon |
| from pycocotools import mask as maskUtils |
|
|
|
|
def renorm(
    img: torch.FloatTensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
) -> torch.FloatTensor:
    """Undo per-channel normalization: return ``img * std + mean``.

    Args:
        img: Channel-first image tensor, either ``(3, H, W)`` or
            ``(N, 3, H, W)``.
        mean: Per-channel offsets added back after scaling.
        std: Per-channel scales multiplied into ``img``.

    Returns:
        A new tensor with the same shape as ``img``. The input is not modified.

    Raises:
        AssertionError: If ``img`` is not 3-D/4-D or its channel axis is not
            of size 3.
    """
    assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
    # The channel axis is 0 for a single image and 1 for a batch.
    channel_dim = img.dim() - 3
    assert img.size(channel_dim) == 3, 'img.size(%d) shoule be 3 but "%d". (%s)' % (
        channel_dim,
        img.size(channel_dim),
        str(img.size()),
    )

    # Build the statistics on the image's own device so tensors that do not
    # live on the CPU can be processed, and shape them (C, 1, 1) so they
    # broadcast over the channel axis of both 3-D and 4-D inputs.
    mean_t = torch.as_tensor(mean, dtype=torch.float32, device=img.device).reshape(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=torch.float32, device=img.device).reshape(-1, 1, 1)
    return img * std_t + mean_t
|
|
|
|
class ColorMap:
    """Turn a single-channel uint8 map into an image tinted with one base colour.

    The base colour fills the leading channels of every pixel and the input
    map is appended as the last channel (the alpha channel when the base
    colour has three components).
    """

    def __init__(self, basergb=[255, 255, 0]):
        # np.array copies the argument, so the shared default list is never
        # modified through this instance.
        self.basergb = np.array(basergb)

    def __call__(self, attnmap):
        """Return the base colour tiled over ``attnmap`` with the map appended.

        Args:
            attnmap: 2-D ``np.uint8`` array of shape ``(h, w)``.

        Returns:
            ``np.uint8`` array of shape ``(h, w, len(basergb) + 1)``.

        Raises:
            AssertionError: If ``attnmap`` is not of dtype ``np.uint8``.
        """
        assert attnmap.dtype == np.uint8
        rows, cols = attnmap.shape
        base = self.basergb
        # Expand the colour to one copy per pixel without materialising it twice.
        colour_plane = np.broadcast_to(base[None, None], (rows, cols) + base.shape)
        alpha_plane = attnmap[..., None]
        stacked = np.concatenate((colour_plane, alpha_plane), axis=-1)
        return stacked.astype(np.uint8)
|
|
|
|
def rainbow_text(x, y, ls, lc, **kw):
    """Write the strings in ``ls`` side by side, each in its own colour.

    ``ls[i]`` is drawn in colour ``lc[i]`` on the current axes, starting at
    data position ``(x, y)`` and continuing horizontally. Each string is
    padded with one space on both sides. Any extra keyword arguments are
    forwarded to ``plt.text`` (font size, family, ...).
    """
    transform = plt.gca().transData
    figure = plt.gcf()
    # NOTE(review): show() is called before anything is drawn; presumably it is
    # here to make a renderer available -- confirm it is still needed.
    plt.show()

    for piece, colour in zip(ls, lc):
        padded = " " + piece + " "
        label = plt.text(x, y, padded, color=colour, transform=transform, **kw)
        # The text has to be drawn once before its on-screen extent is known.
        renderer = figure.canvas.get_renderer()
        label.draw(renderer)
        drawn_width = label.get_window_extent().width
        # Shift the next string right by the width of the one just drawn.
        transform = transforms.offset_copy(label._transform, x=drawn_width, units="dots")
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
class COCOVisualizer:
    """Draw detection targets and COCO-style annotations with matplotlib.

    All drawing methods operate on matplotlib's current figure and axes.
    """

    def __init__(self, coco=None, tokenlizer=None) -> None:
        """Store the optional COCO handle.

        Args:
            coco: Object providing ``imgs`` and ``loadCats``; it is only used
                by ``showAnns`` (presumably a pycocotools ``COCO`` instance --
                TODO confirm against callers).
            tokenlizer: Accepted for call compatibility; not used by this class.
        """
        self.coco = coco

    # ------------------------------------------------------------------ helpers

    def _coco_api(self):
        """Return the object that supplies ``imgs`` and ``loadCats``."""
        # Prefer the handle given at construction. Fall back to ``self`` so a
        # subclass that provides these members directly keeps working.
        return self.coco if self.coco is not None else self

    @staticmethod
    def _random_color():
        """Return a random RGB triple with every channel in [0.4, 1.0)."""
        return (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]

    @staticmethod
    def _rect_patch(bbox_x, bbox_y, bbox_w, bbox_h):
        """Build a rectangular ``Polygon`` from a corner plus width and height."""
        corners = [
            [bbox_x, bbox_y],
            [bbox_x, bbox_y + bbox_h],
            [bbox_x + bbox_w, bbox_y + bbox_h],
            [bbox_x + bbox_w, bbox_y],
        ]
        return Polygon(np.array(corners).reshape((4, 2)))

    @staticmethod
    def _put_label(ax, bbox_x, bbox_y, text, facecolor):
        """Write ``text`` in black at ``(bbox_x, bbox_y)`` on a tinted background."""
        ax.text(
            bbox_x,
            bbox_y,
            text,
            color="black",
            bbox={"facecolor": facecolor, "alpha": 0.6, "pad": 1},
        )

    # --------------------------------------------------------------- public API

    def visualize(self, img, tgt, caption=None, dpi=180, savedir="vis"):
        """Render ``img`` with its targets and save the figure as a PNG.

        Args:
            img: Normalized image tensor of shape ``(3, H, W)``; it is passed
                through ``renorm`` before being displayed.
            tgt: Target dict (all tensors on CPU) or ``None``. ``'image_id'``
                is used in the file name when present (0 otherwise); the
                remaining keys are described in ``addtgt``.
            caption: Optional text inserted into the output file name.
            dpi: Resolution of the created figure.
            savedir: Directory the PNG is written to; created if missing.
        """
        fig = plt.figure(dpi=dpi)
        try:
            plt.rcParams["font.size"] = "5"
            ax = plt.gca()
            img = renorm(img).permute(1, 2, 0)
            ax.imshow(img)

            self.addtgt(tgt)

            if tgt is None or "image_id" not in tgt:
                image_id = 0
            else:
                image_id = tgt["image_id"]

            timestamp = str(datetime.datetime.now()).replace(" ", "-")
            if caption is None:
                savename = "{}/{}-{}.png".format(savedir, int(image_id), timestamp)
            else:
                savename = "{}/{}-{}-{}.png".format(savedir, caption, int(image_id), timestamp)
            print("savename: {}".format(savename))
            os.makedirs(os.path.dirname(savename), exist_ok=True)
            plt.savefig(savename)
        finally:
            # Always release the figure, even when drawing or saving failed.
            plt.close(fig)

    def addtgt(self, tgt):
        """Draw the contents of a target dict on the current axes.

        Recognised keys of ``tgt``:
            ``'boxes'``: tensor of ``(cx, cy, w, h)`` boxes scaled by the image
                size to produce pixel coordinates.
            ``'size'``: ``(H, W)``; required whenever ``'boxes'`` is present.
            ``'strings_positive'`` with ``'labels'``: one list of strings and
                one category id per box, drawn as ``"<id>:<words>"``.
            ``'box_label'``: one label per box.
            ``'caption'``: used as the axes title.
            ``'attn'``: a ``(attn_map, basergb)`` tuple or a list of them;
                each map is min-max scaled and overlaid via ``ColorMap``.

        When ``tgt`` is ``None`` or has no ``'boxes'``, only the caption (if
        any) is drawn. The axes frame is always switched off.

        Raises:
            AssertionError: If the number of ``'strings_positive'`` or
                ``'box_label'`` entries differs from the number of boxes.
        """
        ax = plt.gca()
        if tgt is None or "boxes" not in tgt:
            # Guard the membership test: ``tgt`` itself may be None here.
            if tgt is not None and "caption" in tgt:
                ax.set_title(tgt["caption"], wrap=True)
            ax.set_axis_off()
            return

        H, W = tgt["size"]
        numbox = tgt["boxes"].shape[0]
        scale = torch.Tensor([W, H, W, H])

        color = []
        polygons = []
        boxes = []
        for box in tgt["boxes"].cpu():
            unnormbbox = box * scale
            # Convert (cx, cy, w, h) to (x, y, w, h) with (x, y) the box corner.
            unnormbbox[:2] -= unnormbbox[2:] / 2
            bbox_x, bbox_y, bbox_w, bbox_h = unnormbbox.tolist()
            boxes.append([bbox_x, bbox_y, bbox_w, bbox_h])
            polygons.append(self._rect_patch(bbox_x, bbox_y, bbox_w, bbox_h))
            color.append(self._random_color())

        # A faint fill first, then the coloured outline on top.
        p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1)
        ax.add_collection(p)
        p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
        ax.add_collection(p)

        if "strings_positive" in tgt and len(tgt["strings_positive"]) > 0:
            assert (
                len(tgt["strings_positive"]) == numbox
            ), f"{len(tgt['strings_positive'])} != {numbox}"
            for idx, strlist in enumerate(tgt["strings_positive"]):
                cate_id = int(tgt["labels"][idx])
                _string = str(cate_id) + ":" + " ".join(strlist)
                bbox_x, bbox_y = boxes[idx][:2]
                self._put_label(ax, bbox_x, bbox_y, _string, color[idx])

        if "box_label" in tgt:
            assert len(tgt["box_label"]) == numbox, f"{len(tgt['box_label'])} != {numbox}"
            for idx, bl in enumerate(tgt["box_label"]):
                bbox_x, bbox_y = boxes[idx][:2]
                self._put_label(ax, bbox_x, bbox_y, str(bl), color[idx])

        if "caption" in tgt:
            ax.set_title(tgt["caption"], wrap=True)

        if "attn" in tgt:
            # Normalise to a list locally instead of rewriting the caller's dict.
            attn_items = tgt["attn"]
            if isinstance(attn_items, tuple):
                attn_items = [attn_items]
            for attn_map, basergb in attn_items:
                # Min-max scale to [0, 1); the epsilon avoids dividing by zero
                # for a constant map.
                attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-3)
                attn_map = (attn_map * 255).astype(np.uint8)
                heatmap = ColorMap(basergb)(attn_map)
                ax.imshow(heatmap)
        ax.set_axis_off()

    def showAnns(self, anns, draw_bbox=False):
        """Display the specified annotations on the current axes.

        Args:
            anns: Sequence of annotation dicts. The first entry decides the
                mode: ``'segmentation'``/``'keypoints'`` draws instances,
                ``'caption'`` prints every caption.
            draw_bbox: In instance mode, also outline each ``ann['bbox']``
                given as ``(x, y, w, h)``.

        Returns:
            ``0`` when ``anns`` is empty, otherwise ``None``.

        Raises:
            Exception: If the first annotation matches neither mode.
        """
        if len(anns) == 0:
            return 0
        if "segmentation" in anns[0] or "keypoints" in anns[0]:
            datasetType = "instances"
        elif "caption" in anns[0]:
            datasetType = "captions"
        else:
            raise Exception("datasetType not supported")

        if datasetType == "captions":
            for ann in anns:
                print(ann["caption"])
            return None

        ax = plt.gca()
        ax.set_autoscale_on(False)
        polygons = []
        color = []
        for ann in anns:
            c = self._random_color()
            if "segmentation" in ann:
                segmentation = ann["segmentation"]
                if isinstance(segmentation, list):
                    # Polygon format: each entry is a flat [x0, y0, x1, y1, ...] list.
                    for seg in segmentation:
                        poly = np.array(seg).reshape((len(seg) // 2, 2))
                        polygons.append(Polygon(poly))
                        color.append(c)
                else:
                    # Mask (RLE) format.
                    if isinstance(segmentation["counts"], list):
                        # Uncompressed RLE needs the image size to be encoded.
                        t = self._coco_api().imgs[ann["image_id"]]
                        rle = maskUtils.frPyObjects([segmentation], t["height"], t["width"])
                    else:
                        rle = [segmentation]
                    m = maskUtils.decode(rle)
                    img = np.ones((m.shape[0], m.shape[1], 3))
                    if ann["iscrowd"] == 1:
                        color_mask = np.array([2.0, 166.0, 101.0]) / 255
                    else:
                        color_mask = np.random.random((1, 3)).tolist()[0]
                    img[...] = color_mask
                    # The mask becomes the alpha channel at half opacity.
                    ax.imshow(np.dstack((img, m * 0.5)))
            if "keypoints" in ann and isinstance(ann["keypoints"], list):
                # Skeleton indices are stored 1-based; shift them to 0-based.
                category = self._coco_api().loadCats(ann["category_id"])[0]
                sks = np.array(category["skeleton"]) - 1
                kp = np.array(ann["keypoints"])
                # Keypoints are flattened (x, y, v) triples.
                x = kp[0::3]
                y = kp[1::3]
                v = kp[2::3]
                for sk in sks:
                    if np.all(v[sk] > 0):
                        plt.plot(x[sk], y[sk], linewidth=3, color=c)
                plt.plot(
                    x[v > 0],
                    y[v > 0],
                    "o",
                    markersize=8,
                    markerfacecolor=c,
                    markeredgecolor="k",
                    markeredgewidth=2,
                )
                plt.plot(
                    x[v > 1],
                    y[v > 1],
                    "o",
                    markersize=8,
                    markerfacecolor=c,
                    markeredgecolor=c,
                    markeredgewidth=2,
                )

            if draw_bbox:
                bbox_x, bbox_y, bbox_w, bbox_h = ann["bbox"]
                polygons.append(self._rect_patch(bbox_x, bbox_y, bbox_w, bbox_h))
                color.append(c)

        p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
        ax.add_collection(p)
        return None
|
|