Image-to-Video
File size: 2,141 Bytes
ef296aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
# Set before the remaining imports run. Presumably this lowers TensorFlow's
# C++ log verbosity in case TensorFlow gets loaded through tensorboard --
# TODO confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import torch
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter

from dot.utils.plot import to_rgb


def detach(tensor):
    """Return ``tensor`` detached from autograd and moved to the CPU.

    Anything that is not a ``torch.Tensor`` is handed back unchanged.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    detached = tensor.detach()
    return detached.cpu()


def number(tensor):
    """Replace a tensor containing any NaN with zeros of the same shape.

    If at least one element of a ``torch.Tensor`` is NaN, the whole tensor is
    replaced by ``torch.zeros_like(tensor)``. Tensors without NaNs and
    non-tensor values are returned as they are.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    if torch.isnan(tensor).any():
        return torch.zeros_like(tensor)
    return tensor


class Logger:
    """Write scalars, image grids and video grids to a TensorBoard ``SummaryWriter``.

    Args:
        args: Configuration object. The attributes ``log_path``,
            ``log_factor`` and ``world_size`` are read from it.
    """

    def __init__(self, args):
        self.writer = SummaryWriter(args.log_path)
        # Scale factor passed to ``interpolate`` before grids are written.
        self.factor = args.log_factor
        # Stored for callers; not used by the methods of this class.
        self.world_size = args.world_size

    def log_scalar(self, name, scalar, global_iter):
        """Log one scalar, or every element of a (possibly nested) list/tuple.

        Args:
            name: Tag under which the value is written.
            scalar: Value to log. ``None`` is skipped. A list or tuple is
                expanded recursively, element ``i`` being logged under
                ``f"{name}_{i}"``. Tensors are detached, moved to CPU and
                replaced by zeros if they contain NaN.
            global_iter: Step value forwarded to the writer.
        """
        if scalar is None:
            return
        # isinstance (rather than an exact type comparison) so that tuples and
        # list subclasses are expanded instead of being passed to add_scalar
        # as a single value.
        if isinstance(scalar, (list, tuple)):
            for index, item in enumerate(scalar):
                self.log_scalar(f"{name}_{index}", item, global_iter)
            return
        self.writer.add_scalar(name, number(detach(scalar)), global_iter)

    def log_scalars(self, name, scalars, global_iter):
        """Log every entry of a mapping under the tag ``f"{name}/{key}"``.

        Args:
            name: Common tag prefix.
            scalars: Mapping from key to a value accepted by ``log_scalar``.
            global_iter: Step value forwarded to the writer.
        """
        for key, value in scalars.items():
            self.log_scalar(f"{name}/{key}", value, global_iter)

    def log_image(self, name, tensor, mode, nrow, global_iter, pos=None, occ=None):
        """Convert a batch to RGB, tile it into one grid and log it as an image.

        Args:
            name: Tag under which the image is written.
            tensor: Input forwarded to ``to_rgb`` after detaching.
            mode: Conversion mode forwarded to ``to_rgb``.
            nrow: Number of images per grid row (``make_grid`` argument).
            global_iter: Step value forwarded to the writer.
            pos: Optional argument forwarded to ``to_rgb``.
            occ: Optional argument forwarded to ``to_rgb``.
        """
        rgb = to_rgb(detach(tensor), mode, pos, occ)
        grid = make_grid(rgb, nrow=nrow, normalize=False, value_range=[0, 1], pad_value=0)
        # interpolate is given a leading batch dimension, which is dropped
        # again after resizing.
        grid = torch.nn.functional.interpolate(grid[None], scale_factor=self.factor)[0]
        self.writer.add_image(name, grid, global_iter)

    def log_video(self, name, tensor, mode, nrow, global_iter, fps=4, pos=None, occ=None):
        """Convert a batch of clips to RGB, tile each frame into a grid and log a video.

        Args:
            name: Tag under which the video is written.
            tensor: Input forwarded to ``to_rgb`` (with ``is_video=True``)
                after detaching.
            mode: Conversion mode forwarded to ``to_rgb``.
            nrow: Number of clips per grid row (``make_grid`` argument).
            global_iter: Step value forwarded to the writer.
            fps: Frame rate forwarded to ``add_video``.
            pos: Optional argument forwarded to ``to_rgb``.
            occ: Optional argument forwarded to ``to_rgb``.
        """
        rgb = to_rgb(detach(tensor), mode, pos, occ, is_video=True)
        # One grid per index along dimension 1 (presumably the time axis of
        # the converted tensor -- TODO confirm against to_rgb).
        frames = [
            make_grid(rgb[:, i], nrow=nrow, normalize=False, value_range=[0, 1], pad_value=0)
            for i in range(rgb.shape[1])
        ]
        video = torch.stack(frames, dim=0)
        video = torch.nn.functional.interpolate(video, scale_factor=self.factor)
        # A leading dimension of size one is added before handing the stacked
        # grids to add_video.
        self.writer.add_video(name, video[None], global_iter, fps=fps)