File size: 3,793 Bytes
e168a4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from lib.utils.img_utils import colors
from lib.utils import img_utils
from lib.utils.snake import snake_cityscapes_utils, snake_config


R = 8
GREEN = (18, 127, 15)
WHITE = (255, 255, 255)


def visualize_snake_detection(img, data):

    def blend_hm_img(hm, img):
        hm = np.max(hm, axis=0)
        h, w = hm.shape[:2]
        img = cv2.resize(img, dsize=(w, h), interpolation=cv2.INTER_LINEAR)
        hm = np.array([255, 255, 255]) - (hm.reshape(h, w, 1) * colors[0]).astype(np.uint8)
        ratio = 0.5
        blend = (img * ratio + hm * (1 - ratio)).astype(np.uint8)
        return blend

    img = img_utils.bgr_to_rgb(img)
    blend = blend_hm_img(data['ct_hm'], img)

    plt.imshow(blend)
    ct_ind = np.array(data['ct_ind'])
    w = img.shape[1] // snake_config.down_ratio
    xs = ct_ind % w
    ys = ct_ind // w
    for i in range(len(data['wh'])):
        w, h = data['wh'][i]
        x_min, y_min = xs[i] - w / 2, ys[i] - h / 2
        x_max, y_max = xs[i] + w / 2, ys[i] + h / 2
        plt.plot([x_min, x_min, x_max, x_max, x_min], [y_min, y_max, y_max, y_min, y_min])
    plt.show()


def visualize_cp_detection(img, data):
    act_ind = data['act_ind']
    awh = data['awh']

    act_hm_w = data['act_hm'].shape[2]
    cp_h, cp_w = data['cp_hm'][0].shape[1], data['cp_hm'][0].shape[2]

    img = img_utils.bgr_to_rgb(img)
    plt.imshow(img)

    for i in range(len(act_ind)):
        act_ind_ = act_ind[i]
        ct = act_ind_ % act_hm_w, act_ind_ // act_hm_w
        w, h = awh[i]
        abox = np.array([ct[0] - w/2, ct[1] - h/2, ct[0] + w/2, ct[1] + h/2])

        cp_ind_ = data['cp_ind'][i]
        cp_wh_ = data['cp_wh'][i]

        for j in range(len(cp_ind_)):
            ct = cp_ind_[j] % cp_w, cp_ind_[j] // cp_w
            x = ct[0] / cp_w * w
            y = ct[1] / cp_h * h
            x_min = (x - cp_wh_[j][0] / 2 + abox[0]) * snake_config.down_ratio
            y_min = (y - cp_wh_[j][1] / 2 + abox[1]) * snake_config.down_ratio
            x_max = (x + cp_wh_[j][0] / 2 + abox[0]) * snake_config.down_ratio
            y_max = (y + cp_wh_[j][1] / 2 + abox[1]) * snake_config.down_ratio
            plt.plot([x_min, x_min, x_max, x_max, x_min], [y_min, y_max, y_max, y_min, y_min])

    plt.show()


def visualize_snake_evolution(img, data):
    img = img_utils.bgr_to_rgb(img)
    plt.imshow(img)
    for poly in data['i_gt_py']:
        poly = poly * 4
        poly = np.append(poly, [poly[0]], axis=0)
        plt.plot(poly[:, 0], poly[:, 1])
        plt.scatter(poly[0, 0], poly[0, 1], edgecolors='w')
    plt.show()


def visualize_snake_octagon(img, extreme_points):
    img = img_utils.bgr_to_rgb(img)
    octagons = []
    bboxes = []
    ex_points = []
    for i in range(len(extreme_points)):
        for j in range(len(extreme_points[i])):
            bbox = get_bbox(extreme_points[i][j]*4)
            octagon = snake_cityscapes_utils.get_octagon(extreme_points[i][j]*4)
            bboxes.append(bbox)
            octagons.append(octagon)
            ex_points.append(extreme_points[i][j])
    _, ax = plt.subplots(1)
    ax.imshow(img)
    n = len(octagons)
    for i in range(n):
        x, y, x_max, y_max = bboxes[i]
        ax.add_patch(patches.Polygon(xy=[[x, y], [x, y_max], [x_max, y_max], [x_max, y]], fill=False, linewidth=1,
                                     edgecolor='r'))
        octagon = np.append(octagons[i], octagons[i][0]).reshape(-1, 2)
        ax.plot(octagon[:, 0], octagon[:, 1])
        ax.scatter(ex_points[i][:, 0] * 4, ex_points[i][:, 1] * 4, edgecolors='w')
    plt.show()


def get_bbox(ex):
    x = ex[:, 0]
    y = ex[:, 1]
    bbox = [np.min(x), np.min(y), np.max(x), np.max(y)]
    return bbox