| | import torch |
| | import torch.nn as nn |
| | from torch import optim |
| | from pathlib import Path |
| | from easydict import EasyDict |
| | import pytorch_lightning as pl |
| | import torch.nn.functional as F |
| | import torchvision.transforms as T |
| | from torchvision.utils import save_image |
| | from torchmetrics import MeanSquaredError, StructuralSimilarityIndexMeasure |
| | |
| |
|
| | from . import DenseReg, RenderingLoss |
| | from ..render import Renderer, encode_as_unit_interval, gamma_decode, gamma_encode |
| |
|
class Vanilla(pl.LightningModule):
    """Single-backbone baseline that regresses dense SVBRDF maps (normals,
    albedo, roughness) from an input image.

    Training/validation consume paired ``(src, tgt)`` EasyDict batches and a
    ``DenseReg`` loss; testing additionally reconstructs the input via a
    ``RenderingLoss`` and reports MSE/SSIM per map.
    """

    # Test-time metric names; the first letter selects the map via `maps`.
    metrics = ['I_mse','N_mse','A_mse','R_mse','I_ssim','N_ssim','A_ssim','R_ssim']
    # Metric-prefix -> key of the corresponding map in the batch/output dicts.
    maps = {'I': 'reco', 'N': 'normals', 'R': 'roughness', 'A': 'albedo'}

    def __init__(self, model: nn.Module, loss: DenseReg = None, lr: float = 0, batch_size: int = 0, max_images: int = 10):
        """Store the backbone and training hyper-parameters.

        Args:
            model: network producing raw ``normals``/``albedo``/``roughness``
                maps from a normalized RGB input.
            loss: ``DenseReg`` loss callable used in train/val steps
                (may be None for pure inference).
            lr: Adam learning rate.
            batch_size: per-step batch size; used for loss logging and to cap
                the number of buffered validation images.
            max_images: maximum number of validation images written to the
                experiment logger per epoch.
        """
        super().__init__()
        self.model = model
        self.loss = loss
        self.lr = lr
        self.batch_size = batch_size
        self.tanh = nn.Tanh()
        # ImageNet normalization — assumes the backbone was pretrained with
        # these statistics (TODO confirm against the model's provenance).
        self.norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.val_images = []  # buffered (x, y) pairs for end-of-epoch logging
        self.max_images = max_images

        self.save_hyperparameters(ignore=['model', 'loss'])

    def training_step(self, x):
        """Forward the (src, tgt) pair, log the loss terms, return the total."""
        y = self(*x)
        loss = self.loss(x, y)
        self.log_to('train', loss)
        return dict(loss=loss.total, y=y)

    def forward(self, src, tgt):
        """Run the backbone on whichever of `src`/`tgt` is fully populated.

        NOTE(review): ``None not in src`` is a membership test on the batch
        (keys for a dict, elements for a sequence) — presumably a sentinel
        marking an absent half of the pair; confirm against the data loader.

        Returns:
            (src_out, tgt_out): post-processed output dicts, either possibly
            None when the corresponding input was skipped.
        """
        src_out, tgt_out = None, None

        if None not in src:
            src_out = self.model(self.norm(src.input))
            self.post_process_(src_out)

        if None not in tgt:
            tgt_out = self.model(self.norm(tgt.input))
            self.post_process_(tgt_out)

        return src_out, tgt_out

    def post_process_(self, o: EasyDict):
        """Map raw network outputs to their final value ranges, in place.

        Trailing underscore follows the torch convention for in-place ops.
        """
        # Normals: squash xy to [-1,1], scale by 3 (allows steeper slopes
        # than unit tanh — presumably a dataset-specific choice, confirm),
        # fix z = 1, then renormalize to unit vectors.
        nxy = self.tanh(o.normals)
        nx, ny = torch.split(nxy*3, split_size_or_sections=1, dim=1)
        n = torch.cat([nx, ny, torch.ones_like(nx)], dim=1)
        o.normals = F.normalize(n, dim=1)

        # Albedo: [-1,1] -> [0,1].
        a = self.tanh(o.albedo)
        o.albedo = encode_as_unit_interval(a)

        # Roughness: single channel, broadcast to 3 channels in [0,1] so it
        # can be logged/compared like an RGB image.
        r = self.tanh(o.roughness)
        o.roughness = encode_as_unit_interval(r.repeat(1,3,1,1))

    def validation_step(self, x, *_):
        """Same as training but also buffers up to `max_images` batches for
        image logging at epoch end."""
        y = self(*x)
        loss = self.loss(x, y)
        self.log_to('val', loss)

        if len(self.val_images) * self.batch_size < self.max_images and self.logger:
            self.val_images.append((x, y))

        return dict(loss=loss.total, y=y)

    def on_validation_epoch_end(self):
        """Write buffered validation predictions and ground truths to the
        experiment logger, then release the buffer."""
        cur_ind = 0
        for ind, (x, y) in enumerate(self.val_images):
            for key in ['normals', 'albedo', 'roughness']:
                for i in range(len(x[0][key])):
                    pred_image = y[0][key][i]
                    gt_image = x[0][key][i]

                    self.logger.experiment.add_image(f'val/{cur_ind}_{key}_pred', pred_image, self.global_step)
                    self.logger.experiment.add_image(f'val/{cur_ind}_{key}_gt', gt_image, self.global_step)
            cur_ind += 1
            if cur_ind >= self.max_images:
                break

        self.val_images.clear()

    def log_to(self, split, loss):
        """Log every entry of the loss dict under the `split/` prefix."""
        self.log_dict({f'{split}/{k}': v for k, v in loss.items()}, batch_size=self.batch_size)

    def log_images(self, x, y, split, max_images=5):
        """Log up to `max_images` predicted/GT maps per key to the logger.

        Accepts either the paired ``(src, tgt)`` tuples used in validation or
        a single EasyDict batch (as passed by `test_step`); the latter is
        wrapped so the ``x[0]`` indexing below is uniform. This fixes a
        KeyError previously raised when called from `test_step` with bare
        dicts.
        """
        if not isinstance(x, (tuple, list)):
            x, y = (x,), (y,)

        for key in ['normals', 'albedo', 'roughness']:
            for i in range(min(len(x[0][key]), max_images)):
                self.logger.experiment.add_image(f'{split}/{key}_pred_{i}', y[0][key][i], self.global_step)
                self.logger.experiment.add_image(f'{split}/{key}_gt_{i}', x[0][key][i], self.global_step)

    def on_test_start(self):
        """Build the renderer and one torchmetrics meter per metric name.

        Meters are created lazily here (not in __init__) so they land on the
        device Lightning has already placed the module on.
        """
        self.renderer = RenderingLoss(Renderer())

        for m in Vanilla.metrics:
            if 'mse' in m:
                setattr(self, m, MeanSquaredError().to(self.device))
            elif 'ssim' in m:
                setattr(self, m, StructuralSimilarityIndexMeasure(data_range=1).to(self.device))

    def test_step(self, x, batch_idx, dl_id=0):
        """Predict maps, render a reconstruction, and update/log all metrics."""
        y = self.model(self.norm(x.input))
        self.post_process_(y)

        # Render predicted maps back to an image; compare in linear space.
        y.reco = self.renderer.reconstruction(y, x.input_params)
        x.reco = gamma_decode(x.input)

        for m in Vanilla.metrics:
            mapid, *_ = m  # first character of the metric name picks the map
            k = Vanilla.maps[mapid]
            meter = getattr(self, m)
            meter(y[k], x[k].to(y[k].dtype))
            # Passing the metric object lets Lightning aggregate per epoch.
            self.log(m, getattr(self, m), on_epoch=True)

        if self.logger:
            self.log_images(x, y, split='test')

    def predict_step(self, x, batch_idx):
        """Predict maps for a single-item batch and save them next to the input."""
        y = self.model(self.norm(x.input))
        self.post_process_(y)

        I, name, outdir = x.input[0], x.name[0], Path(x.path[0]).parent
        N_pred, A_pred, R_pred = y.normals[0], y.albedo[0], y.roughness[0]

        save_image(gamma_encode(A_pred), outdir/f'{name}_albedo.png')
        save_image(encode_as_unit_interval(N_pred), outdir/f'{name}_normals.png')
        save_image(R_pred, outdir/f'{name}_roughness.png')

    def configure_optimizers(self):
        """Adam over the backbone's parameters with light weight decay."""
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-4)
        return dict(optimizer=optimizer)
| |
|