File size: 4,031 Bytes
04c78c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from pathlib import Path

import torch
import pytorch_lightning as pl
from torchvision.utils import make_grid, save_image
from torchvision.transforms import Resize

from capture.render import encode_as_unit_interval, gamma_encode


class VisualizeCallback(pl.Callback):
    """Lightning callback that saves image grids of ground-truth vs. predicted maps.

    On every epoch divisible by ``log_every_n_epoch``, the first
    ``n_batches_shown`` train and validation batches are rendered to JPEG files
    in ``out_dir/'images_1'``: one file per (split, epoch, domain, batch index),
    with one tiled block per sample of the batch stacked vertically.
    """

    def __init__(self, exist_ok: bool, out_dir: Path, log_every_n_epoch: int, n_batches_shown: int,
                 image_size: int = 128):
        """Prepare the output directory and the thumbnail resizer.

        Args:
            exist_ok: if False and the image directory already contains entries,
                prompt on stdin; any answer other than ``'y'`` terminates the
                process with exit status 1.
            out_dir: run directory (a ``str`` is also accepted); images are
                written to ``out_dir/'images_1'``, which is created if missing.
            log_every_n_epoch: visualize only on epochs where
                ``current_epoch % log_every_n_epoch == 0``.
            n_batches_shown: number of leading batches (by batch index) that are
                visualized on a logged epoch.
            image_size: side length in pixels every map is resized to before
                tiling; the default 128 matches the previously hard-coded size.
        """
        super().__init__()

        self.out_dir = Path(out_dir)/'images_1'
        # any() stops at the first entry instead of materializing the listing.
        if not exist_ok and self.out_dir.is_dir() and any(self.out_dir.iterdir()):
            # Report the directory that was actually checked (the images subdir).
            print(f'directory {self.out_dir} already exists, press \'y\' to proceed')
            if input() != 'y':
                # Raise directly: the builtin exit() is injected by the site
                # module and intended for the interactive shell.
                raise SystemExit(1)

        self.out_dir.mkdir(parents=True, exist_ok=True)

        self.log_every_n_epoch = log_every_n_epoch
        self.n_batches_shown = n_batches_shown
        self.resize = Resize(size=[image_size, image_size], antialias=True)

    def setup(self, trainer, module, stage):
        """Keep a handle on the trainer's logger (not used by the methods below)."""
        self.logger = trainer.logger

    def on_train_batch_end(self, *args):
        """Forward the train-batch hook arguments to ``_on_batch_end``."""
        self._on_batch_end(*args, split='train')

    def on_validation_batch_end(self, *args):
        """Forward the validation-batch hook arguments to ``_on_batch_end``."""
        self._on_batch_end(*args, split='valid')

    def _on_batch_end(self, trainer, module, outputs, inputs, batch, *args, split):
        """Visualize the current batch when it falls inside the logged window.

        Args:
            trainer: provides ``current_epoch``.
            module: unused.
            outputs: step output whose ``'y'`` entry is the (source, target)
                prediction pair, or a list of such outputs indexed by optimizer.
            inputs: the batch as a (source, target) pair.
            batch: index of the batch -- presumably Lightning's 0-based
                ``batch_idx`` given the hook's positional order; verify.
            *args: remaining hook arguments (e.g. dataloader index), ignored.
            split: ``'train'`` or ``'valid'``; used in the output file name.
        """
        epoch = trainer.current_epoch
        # Guard first so inputs/outputs are only unpacked for batches we draw.
        # With a 0-based index, ``<`` selects exactly n_batches_shown batches
        # (the previous ``<=`` drew one extra).
        if epoch % self.log_every_n_epoch != 0 or batch >= self.n_batches_shown:
            return

        x_src, x_tgt = inputs

        # optim_idx:0=discr & optim_idx:1=generator
        y_src, y_tgt = outputs[1]['y'] if isinstance(outputs, list) else outputs['y']

        # A falsy entry (e.g. None) means there is nothing to draw for that domain.
        if x_src and y_src:
            self._visualize_src(x_src, y_src, split=split, epoch=epoch, batch=batch, ds='src')
        if x_tgt and y_tgt:
            self._visualize_tgt(x_tgt, y_tgt, split=split, epoch=epoch, batch=batch, ds='tgt')

    def _save_grid(self, rows, split, epoch, batch, ds):
        """Stack per-sample grids vertically and write them as one JPEG.

        The file is ``self.out_dir/f'{split}{epoch:05d}_{ds}_{batch}.jpg'``.
        Nothing is written for an empty batch, because ``save_image`` cannot
        tile an empty list.
        """
        if not rows:
            return
        name = self.out_dir/f'{split}{epoch:05d}_{ds}_{batch}.jpg'
        save_image(rows, name, nrow=1, padding=5)

    def _tile(self, maps, nrow):
        """Move each map to CPU, resize it to the thumbnail size, tile without padding."""
        return make_grid([self.resize(m.cpu()) for m in maps], nrow=nrow, padding=0)

    def _visualize_src(self, x, y, split, epoch, batch, ds):
        """Save the ground-truth / prediction grid for a source-domain batch."""
        # Current layout has no displacement maps; the layout that also tiles
        # displacement is _visualize_single_src_previous.
        zipped = zip(x.albedo, x.roughness, x.normals, x.input, x.image,
                     y.albedo, y.roughness, y.normals, y.reco, y.image)
        rows = [self._visualize_single_src(*z) for z in zipped]
        self._save_grid(rows, split, epoch, batch, ds)

    @torch.no_grad()
    def _visualize_single_src(self, a, r, n, input, mv, a_p, r_p, n_p, reco, mv_p):
        """Tile one source sample as ground-truth maps followed by predicted maps.

        Ground truth: input, albedo, roughness, normals, then each element of
        ``mv``. Prediction: reconstruction, albedo, roughness, normals, then each
        element of ``mv_p``. Normals go through ``encode_as_unit_interval``;
        ``reco`` and every ``mv``/``mv_p`` element go through ``gamma_encode``.
        The grid has ``len(maps) // 2`` columns, so when ``mv`` and ``mv_p`` have
        equal length the first row is ground truth and the second prediction.
        """
        gt = [input, a, r, encode_as_unit_interval(n)] + [gamma_encode(o) for o in mv]
        pred = ([gamma_encode(reco), a_p, r_p, encode_as_unit_interval(n_p)]
                + [gamma_encode(o) for o in mv_p])
        maps = gt + pred
        return self._tile(maps, nrow=len(maps) // 2)

    @torch.no_grad()
    def _visualize_single_src_previous(self, a, r, n, d, input, mv, a_p, r_p, n_p, d_p, reco, mv_p):
        """Same as ``_visualize_single_src`` but also tiles displacement maps.

        Not called anywhere in this class; kept for the layout that includes
        displacement (``d`` after normals, ``d_p`` after predicted normals).
        """
        gt = [input, a, r, encode_as_unit_interval(n), d] + [gamma_encode(o) for o in mv]
        pred = ([gamma_encode(reco), a_p, r_p, encode_as_unit_interval(n_p), d_p]
                + [gamma_encode(o) for o in mv_p])
        maps = gt + pred
        return self._tile(maps, nrow=len(maps) // 2)

    def _visualize_tgt(self, x, y, split, epoch, batch, ds):
        """Save the prediction grid for a target-domain batch (inputs + predictions only)."""
        zipped = zip(x.input, y.albedo, y.roughness, y.normals, y.displacement)
        rows = [self._visualize_single_tgt(*z) for z in zipped]
        self._save_grid(rows, split, epoch, batch, ds)

    @torch.no_grad()
    def _visualize_single_tgt(self, input, a_p, r_p, n_p, d_p):
        """Tile one target sample in a single row.

        Order: input, predicted albedo, roughness, normals (passed through
        ``encode_as_unit_interval``) and displacement.
        """
        maps = [input, a_p, r_p, encode_as_unit_interval(n_p), d_p]
        return self._tile(maps, nrow=len(maps))