File size: 5,178 Bytes
c6dfc69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os

import PIL
import matplotlib.pyplot as plt
import numpy
import torch
import torchvision
try:
    import wandb
except ImportError:  # pragma: no cover
    wandb = None

# from utils.visualize import show_img


# Class name -> (R, G, B) colour triple; insertion order is preserved.
color_map = {
    "background": (0, 0, 0),
    "longitudinal": (128, 0, 0),
    "pothole": (0, 128, 0),
    "alligator": (128, 128, 0),
    "transverse": (128, 0, 128),
    "ignore": (255, 255, 255),
}


class _DummyWandb:
    def log(self, *args, **kwargs):
        return None


class Tensorboard:
    """Experiment logger backed by Weights & Biases (wandb).

    Despite the name, nothing here talks to TensorBoard. Metrics and images go
    to a wandb run when ``config['wandb_online']`` is truthy and the ``wandb``
    package was importable. Otherwise every call goes to a no-op
    ``_DummyWandb`` object, so training code can log unconditionally.
    """

    def __init__(self, config):
        """Set up the logging backend and the visualisation transforms.

        Args:
            config: dict-like experiment configuration. Keys read here:
                ``wandb_online`` enables wandb.
                ``wandb_key`` is an optional API key; it falls back to
                ``$WANDB_API_KEY``.
                ``proj_name`` and ``experiment_name`` are required only when
                online.
                ``image_mean`` and ``image_std`` are the per-channel values
                handed to DeNormalize.
                ``palette`` is an optional flat [R0, G0, B0, R1, ...] list that
                ``de_normalize`` uses to colour masks.
        """
        self._log_images = bool(config.get('wandb_online', False))
        if not self._log_images or wandb is None or not hasattr(wandb, "init"):
            self.tensor_board = _DummyWandb()
            self._log_images = False
        else:
            key = config.get('wandb_key') or os.environ.get('WANDB_API_KEY', '')
            if key:
                os.environ['WANDB_API_KEY'] = key
                wandb.login(key=key, relogin=False)
            self.tensor_board = wandb.init(project=config['proj_name'], name=config['experiment_name'],
                                           config=config, settings=wandb.Settings(code_dir="."))

        # Flat [R0, G0, B0, R1, G1, B1, ...] palette read by de_normalize().
        # The default assumes mask value i corresponds to the i-th colour of
        # color_map. TODO confirm against the dataset's label encoding, or pass
        # config['palette'] explicitly.
        palette = config.get('palette')
        if palette is None:
            palette = [channel for rgb in color_map.values() for channel in rgb]
        self.palette = list(palette)

        self.restore_transform = torchvision.transforms.Compose([
            DeNormalize(config['image_mean'], config['image_std']),
            torchvision.transforms.ToPILImage()])

    def upload_wandb_info(self, info_dict):
        """Log a mapping of metric name -> value as a single step.

        Everything goes into one ``log`` call. wandb advances its step counter
        on every ``log`` call, so logging key by key would scatter metrics that
        belong together across different steps. An empty mapping logs nothing.
        """
        if not info_dict:
            return
        self.tensor_board.log(dict(info_dict))

    def upload_wandb_image(self, frames, pseudo_label_from_pred, pseudo_label_from_sam, img_number=4):
        """Log up to ``img_number`` input frames together with two pseudo-label masks.

        Does nothing when image logging is disabled.

        Args:
            frames: tensor or array shaped [N, C, H, W] or [C, H, W]. Frames
                with 3 channels are un-normalised with the configured
                ``image_mean``/``image_std`` and clamped to [0, 1].
            pseudo_label_from_pred: mask(s) shaped [H, W], [N, H, W] or
                [N, 1, H, W]. Logged under the "logits" key.
            pseudo_label_from_sam: mask(s) with the same accepted shapes.
                Logged under the "label" key.
            img_number: maximum number of samples to log.

        Raises:
            ValueError: if an input has an unsupported shape.
        """
        if not self._log_images:
            return

        def _batched_rgb(t):
            """Return frames as a detached float CPU tensor of shape [N, C, H, W]."""
            if not isinstance(t, torch.Tensor):
                t = torch.as_tensor(t)
            t = t.detach().cpu().float()
            if t.dim() == 3:
                return t.unsqueeze(0)
            if t.dim() == 4:
                return t
            raise ValueError("frames must be [C,H,W] or [N,C,H,W], got shape {}".format(tuple(t.shape)))

        def _batched_mask(t):
            """Return masks as a detached float CPU tensor of shape [N, H, W]."""
            if not isinstance(t, torch.Tensor):
                t = torch.as_tensor(t)
            t = t.detach().cpu().float()
            # Drop singleton channel dims ([N,1,H,W] -> [N,H,W]). A non-singleton
            # dim 1 cannot be squeezed, so reject it instead of looping forever.
            while t.dim() > 3:
                if t.shape[1] != 1:
                    raise ValueError(
                        "masks must be [H,W], [N,H,W] or [N,1,H,W], got shape {}".format(tuple(t.shape)))
                t = t.squeeze(1)
            if t.dim() == 2:
                t = t.unsqueeze(0)
            if t.dim() != 3:
                raise ValueError("masks must be [H,W], [N,H,W] or [N,1,H,W], got shape {}".format(tuple(t.shape)))
            return t

        frames = _batched_rgb(frames)
        pseudo_label_from_pred = _batched_mask(pseudo_label_from_pred)
        pseudo_label_from_sam = _batched_mask(pseudo_label_from_sam)

        n = min(frames.shape[0], pseudo_label_from_pred.shape[0], pseudo_label_from_sam.shape[0], img_number)
        frames = frames[:n]
        # masked_fill returns a new tensor, so the caller's masks are not modified
        # (detach().cpu().float() may share storage with them). The value 255
        # (presumably an ignore label; confirm) is shown as mid-grey 0.5.
        pseudo_label_from_sam = pseudo_label_from_sam[:n].masked_fill(pseudo_label_from_sam[:n] == 255., 0.5)
        pseudo_label_from_pred = pseudo_label_from_pred[:n].masked_fill(pseudo_label_from_pred[:n] == 255., 0.5)

        denorm = self.restore_transform.transforms[0]
        image_list = []
        label_list = []
        logits_list = []
        for i in range(n):
            caption = "id {}".format(i)
            fi = frames[i].clone()  # DeNormalize works in place; keep `frames` intact
            if fi.shape[0] == 3:
                denorm(fi)
                fi.clamp_(0.0, 1.0)
            image_list.append(wandb.Image(fi, caption=caption))
            # Masks are [H, W] here. wandb.Image expects torch tensors as [C, H, W]
            # (it permutes CHW -> HWC). unsqueeze(0) keeps H==1 / W==1 masks 2-D,
            # whereas squeeze() would have collapsed them.
            label_list.append(wandb.Image(pseudo_label_from_sam[i].unsqueeze(0), caption=caption))
            logits_list.append(wandb.Image(pseudo_label_from_pred[i].unsqueeze(0), caption=caption))

        self.tensor_board.log({"image": image_list, "label": label_list, "logits": logits_list})

    def de_normalize(self, image):
        """Convert each element of ``image`` into a PIL image for visualisation.

        3-D tensors (presumably [C, H, W]; confirm with callers) go through
        ``restore_transform``, which un-normalises them and converts them to
        PIL. Everything else is treated as a label mask and colourised with
        ``self.palette``. The inputs are never modified.

        Returns:
            list of PIL images, in input order.
        """
        restored = []
        for item in image:
            if isinstance(item, torch.Tensor) and item.dim() == 3:
                # clone(): restore_transform un-normalises in place, and
                # detach().cpu() on a CPU tensor shares storage with the input.
                restored.append(self.restore_transform(item.detach().cpu().clone()))
            else:
                mask = item.detach().cpu().numpy() if isinstance(item, torch.Tensor) else numpy.asarray(item)
                restored.append(colorize_mask(mask, list(self.palette)))
        return restored

    def finish(self):
        """Close the wandb run. This is a no-op when the backend has no ``finish``."""
        finish = getattr(self.tensor_board, "finish", None)
        if callable(finish):
            finish()


class DeNormalize:
    """Invert per-channel ``(x - mean) / std`` normalisation, in place.

    Each slice ``tensor[c]`` along the first dimension becomes
    ``tensor[c] * std[c] + mean[c]``. Because the pairing uses ``zip``, slices
    beyond ``len(mean)`` or ``len(std)`` are left untouched.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Rescale ``tensor`` in place and return the very same object."""
        for channel, channel_mean, channel_std in zip(tensor, self.mean, self.std):
            channel.mul_(channel_std)
            channel.add_(channel_mean)
        return tensor


def colorize_mask(mask, palette):
    """Turn a label mask into a palettised ('P' mode) PIL image.

    Args:
        mask: array-like of class indices, presumably 2-D [H, W]. Values are
            cast to uint8, so they should lie in [0, 255].
        palette: flat ``[R0, G0, B0, R1, G1, B1, ...]`` sequence (list or
            tuple). A palette shorter than 256 entries is zero-padded, so
            unlisted indices render black. The caller's sequence is not
            modified.

    Returns:
        PIL image in mode 'P' whose pixel indices come from ``mask``.
    """
    # `import PIL` alone does not load the Image submodule, so import it explicitly.
    from PIL import Image

    padded = list(palette)  # copy: never mutate the caller's palette
    padded.extend([0] * (256 * 3 - len(padded)))
    new_mask = Image.fromarray(numpy.asarray(mask).astype(numpy.uint8)).convert('P')
    new_mask.putpalette(padded)
    return new_mask