File size: 2,076 Bytes
c6dfc69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""Optional Weights & Biases logging for Ref-AVS training."""
import os

import torchvision
import wandb


class Tensorboard:
    """Thin wrapper around a Weights & Biases run for scalar and image logging.

    Despite the class name, all logging goes through ``wandb``; the run
    object returned by ``wandb.init`` is kept in ``self.tensor_board``.
    """

    def __init__(self, config):
        """Start a wandb run described by *config*.

        Args:
            config: Mapping read with ``[]``, ``.get`` and ``.items``.
                Required keys: ``proj_name``, ``experiment_name``,
                ``image_mean``, ``image_std``. Optional keys: ``wandb_key``
                (API key; falls back to the ``WANDB_API_KEY`` environment
                variable) and ``wandb_online`` (truthy selects ``'online'``
                mode, anything else selects ``'disabled'``).
        """
        key = config.get('wandb_key') or os.environ.get('WANDB_API_KEY', '')
        if key:
            os.environ['WANDB_API_KEY'] = key
        mode = 'online' if config.get('wandb_online', False) else 'disabled'
        # The run config is uploaded with the run, so the API key must not be
        # part of it; everything else is forwarded unchanged.
        run_config = {k: v for k, v in config.items() if k != 'wandb_key'}
        self.tensor_board = wandb.init(
            project=config['proj_name'],
            name=config['experiment_name'],
            config=run_config,
            mode=mode,
            settings=wandb.Settings(code_dir='.'),
        )
        # Built here for callers; the methods of this class do not apply it.
        self.restore_transform = torchvision.transforms.Compose([
            DeNormalize(config['image_mean'], config['image_std']),
            torchvision.transforms.ToPILImage(),
        ])

    def upload_wandb_info(self, info_dict):
        """Log every entry of *info_dict* to the run in a single call.

        One ``log`` call keeps all entries on the same wandb step; logging
        them one at a time would commit a separate step per key. An empty
        mapping logs nothing.
        """
        if info_dict:
            self.tensor_board.log(dict(info_dict))

    @staticmethod
    def _display_mask(mask, count):
        """Return a float copy of ``mask[:count]`` with 255 replaced by 0.5.

        ``masked_fill`` is out-of-place, so the caller's tensor is never
        modified, even when it is already float and ``.float()`` is a no-op.
        """
        mask = mask[:count].float()
        # 255 is presumably the "ignore" label -- TODO confirm with caller.
        return mask.masked_fill(mask == 255., 0.5)

    def upload_wandb_image(self, frames, pseudo_label_from_pred, pseudo_label_from_sam, img_number=4):
        """Log up to *img_number* frames with their label and prediction masks.

        Args:
            frames: Indexable batch of images; each item is handed to
                ``wandb.Image`` as-is.
            pseudo_label_from_pred: Tensor whose first dimension is the batch;
                logged under ``'logits'``.
            pseudo_label_from_sam: Tensor whose first dimension is the batch;
                logged under ``'label'``.
            img_number: Maximum number of samples to upload.

        The input tensors are left unmodified.
        """
        n = min(pseudo_label_from_pred.shape[0], img_number)
        frames = frames[:n]
        label_masks = self._display_mask(pseudo_label_from_sam, n)
        pred_masks = self._display_mask(pseudo_label_from_pred, n)
        self.tensor_board.log({
            'image': [wandb.Image(j, caption=f'id {i}') for i, j in enumerate(frames)],
            'label': [wandb.Image(j.squeeze(), caption=f'id {i}') for i, j in enumerate(label_masks)],
            'logits': [wandb.Image(j.squeeze(), caption=f'id {i}') for i, j in enumerate(pred_masks)],
        })

    def finish(self):
        """Finish the underlying wandb run."""
        self.tensor_board.finish()


class DeNormalize(object):
    """Invert a ``(x - mean) / std`` normalization, slice by slice.

    Each slice along the first dimension of the input (the channels of a
    CHW image, presumably) is multiplied by the matching ``std`` entry and
    then shifted by the matching ``mean`` entry.
    """

    def __init__(self, mean, std):
        """Store the per-slice *mean* and *std* sequences."""
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Return a de-normalized copy of *tensor*.

        The work is done on a clone so the caller's tensor keeps its values;
        mutating it in place would corrupt it if it were reused or passed
        through this transform a second time.
        """
        restored = tensor.clone()
        # Iterating yields views into ``restored``, so in-place ops update it.
        for channel, m, s in zip(restored, self.mean, self.std):
            channel.mul_(s).add_(m)
        return restored