File size: 5,468 Bytes
04c78c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import torch.nn as nn
from torch import optim
from pathlib import Path
from easydict import EasyDict
import pytorch_lightning as pl
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.utils import save_image
from torchmetrics import MeanSquaredError, StructuralSimilarityIndexMeasure
#from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

from . import DenseReg, RenderingLoss
from ..render import Renderer, encode_as_unit_interval, gamma_decode, gamma_encode

class Vanilla(pl.LightningModule):
    """Lightning wrapper around a dense-regression model predicting normals,
    albedo and roughness maps from an input image.

    Training/validation batches are (src, tgt) pairs of dict-like objects
    (each with at least an ``input`` tensor); test/predict batches are a
    single dict-like object.  NOTE(review): exact batch schema is assumed
    from usage here — confirm against the datamodule.
    """

    # Metric names encode '<map-initial>_<metric>'; the initial is resolved
    # to a batch/output key through `maps` below (see `test_step`).
    metrics = ['I_mse','N_mse','A_mse','R_mse','I_ssim','N_ssim','A_ssim','R_ssim']
    maps = {'I': 'reco', 'N': 'normals', 'R': 'roughness', 'A': 'albedo'}

    def __init__(self, model: nn.Module, loss: DenseReg = None, lr: float = 0, batch_size: int = 0, max_images: int = 10):
        """
        Args:
            model: backbone returning a dict-like output with 'normals',
                'albedo' and 'roughness' entries (see `post_process_`).
            loss: criterion used in training/validation steps.
            lr: Adam learning rate (see `configure_optimizers`).
            batch_size: used only for metric logging and for capping the
                number of validation batches kept for image logging.
            max_images: upper bound on validation image groups logged per
                epoch (see `on_validation_epoch_end`).
        """
        super().__init__()
        self.model = model
        self.loss = loss
        self.lr = lr
        self.batch_size = batch_size
        self.tanh = nn.Tanh()
        # ImageNet normalization applied to the model input.
        self.norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.val_images = []  # (batch, prediction) pairs buffered for epoch-end logging
        self.max_images = max_images

        # model/loss are nn.Modules; checkpoint them via state_dict, not hparams.
        self.save_hyperparameters(ignore=['model', 'loss'])

    def training_step(self, x):
        # x is a (src, tgt) pair; forward handles either side being absent.
        y = self(*x)
        loss = self.loss(x, y)
        self.log_to('train', loss)
        return dict(loss=loss.total, y=y)

    def forward(self, src, tgt):
        """Run the model on each side of the pair that is fully populated
        (i.e. contains no None entries); skipped sides yield None."""
        src_out, tgt_out = None, None

        if None not in src:
            src_out = self.model(self.norm(src.input))
            self.post_process_(src_out)

        if None not in tgt:
            tgt_out = self.model(self.norm(tgt.input))
            self.post_process_(tgt_out)

        return src_out, tgt_out

    def post_process_(self, o: EasyDict):
        """Map raw network outputs to their final ranges, in place."""
        # Normals: tanh, then split the 2-channel (x, y) prediction, append a
        # constant z=1 channel and renormalize to unit vectors.
        # NOTE(review): the *3 widens the xy range relative to the fixed z=1
        # before normalization — presumably intentional steepness scaling; confirm.
        nxy = self.tanh(o.normals)
        nx, ny = torch.split(nxy*3, split_size_or_sections=1, dim=1)
        n = torch.cat([nx, ny, torch.ones_like(nx)], dim=1)
        o.normals = F.normalize(n, dim=1)

        # Albedo: tanh to [-1,1], then remap to [0,1].
        a = self.tanh(o.albedo)
        o.albedo = encode_as_unit_interval(a)

        # Roughness: tanh to [-1,1], remap to [0,1], repeat to 3 channels.
        r = self.tanh(o.roughness)
        o.roughness = encode_as_unit_interval(r.repeat(1,3,1,1))

    def validation_step(self, x, *_):
        y = self(*x)
        loss = self.loss(x, y)
        self.log_to('val', loss)

        # Buffer up to ~max_images worth of batches for epoch-end image logging.
        if len(self.val_images) * self.batch_size < self.max_images and self.logger:
            self.val_images.append((x, y))

        return dict(loss=loss.total, y=y)

    def on_validation_epoch_end(self):
        """Log buffered validation predictions vs. ground truth, then clear
        the buffer.  At most `max_images` groups are logged.

        Fix: the original `break` only exited the inner key loop, so the
        outer loop kept logging images from later batches after the cap was
        reached; a flag now stops the outer loop as well.
        """
        cur_ind = 0
        done = False
        for x, y in self.val_images:
            if done:
                break
            for key in ['normals', 'albedo', 'roughness']:
                for i in range(len(x[0][key])):
                    pred_image = y[0][key][i]
                    gt_image = x[0][key][i]

                    self.logger.experiment.add_image(f'val/{cur_ind}_{key}_pred', pred_image, self.global_step)
                    self.logger.experiment.add_image(f'val/{cur_ind}_{key}_gt', gt_image, self.global_step)
                cur_ind += 1
                if cur_ind >= self.max_images:
                    done = True
                    break

        self.val_images.clear()

    def log_to(self, split, loss):
        # Prefix every loss component with the split name ('train'/'val').
        self.log_dict({f'{split}/{k}': v for k, v in loss.items()}, batch_size=self.batch_size)

    def log_images(self, x, y, split, max_images=5):
        """Log up to `max_images` predicted/ground-truth pairs per map type.
        NOTE(review): indexes `x[0]`/`y[0]` — written for paired (src, tgt)
        batches; `test_step` calls it with an unpaired batch, verify."""
        for key in ['normals', 'albedo', 'roughness']:
            for i in range(min(len(x[0][key]), max_images)):
                self.logger.experiment.add_image(f'{split}/{key}_pred_{i}', y[0][key][i], self.global_step)
                self.logger.experiment.add_image(f'{split}/{key}_gt_{i}', x[0][key][i], self.global_step)


    def on_test_start(self):
        # Build the renderer and one metric module per entry in `metrics`,
        # attached as attributes so `test_step` can look them up by name.
        self.renderer = RenderingLoss(Renderer())

        for m in Vanilla.metrics:
            if 'mse' in m:
                setattr(self, m, MeanSquaredError().to(self.device))
            elif 'ssim' in m:
                setattr(self, m, StructuralSimilarityIndexMeasure(data_range=1).to(self.device))

    def test_step(self, x, batch_idx, dl_id=0):
        y = self.model(self.norm(x.input))
        self.post_process_(y)

        # Image reconstruction: render from predicted maps; ground truth is
        # the gamma-decoded input.
        y.reco = self.renderer.reconstruction(y, x.input_params)
        x.reco = gamma_decode(x.input)

        # Update each metric on the map selected by its leading initial.
        for m in Vanilla.metrics:
            mapid, *_ = m
            k = Vanilla.maps[mapid]
            meter = getattr(self, m)
            meter(y[k], x[k].to(y[k].dtype))
            self.log(m, getattr(self, m), on_epoch=True)

        if self.logger:
            self.log_images(x, y, split='test')

    def predict_step(self, x, batch_idx):
        """Predict maps for one batch and save the first sample's maps as
        PNGs next to the input file."""
        y = self.model(self.norm(x.input))
        self.post_process_(y)

        I, name, outdir = x.input[0], x.name[0], Path(x.path[0]).parent
        N_pred, A_pred, R_pred  = y.normals[0], y.albedo[0], y.roughness[0]

        save_image(gamma_encode(A_pred), outdir/f'{name}_albedo.png')
        save_image(encode_as_unit_interval(N_pred), outdir/f'{name}_normals.png')
        save_image(R_pred, outdir/f'{name}_roughness.png')

    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-4)
        return dict(optimizer=optimizer)