import torch
import torch.nn as nn
from torch import optim
from pathlib import Path
from easydict import EasyDict
import pytorch_lightning as pl
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.utils import save_image
from torchmetrics import MeanSquaredError, StructuralSimilarityIndexMeasure
#from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from . import DenseReg, RenderingLoss
from ..render import Renderer, encode_as_unit_interval, gamma_decode, gamma_encode
class Vanilla(pl.LightningModule):
    """SVBRDF estimation module.

    Wraps a dense-prediction backbone that maps an input photograph to
    normals / albedo / roughness maps, trains it with a ``DenseReg`` loss on
    (src, tgt) pairs, evaluates with MSE/SSIM metrics plus a rendering-based
    image reconstruction, and writes predicted maps to disk at predict time.
    """

    # Metric names encode '<map id>_<metric>'; `maps` resolves the map id to
    # the corresponding key in the prediction / ground-truth dicts.
    metrics = ['I_mse','N_mse','A_mse','R_mse','I_ssim','N_ssim','A_ssim','R_ssim']
    maps = {'I': 'reco', 'N': 'normals', 'R': 'roughness', 'A': 'albedo'}

    def __init__(self, model: nn.Module, loss: DenseReg = None, lr: float = 0, batch_size: int = 0, max_images: int = 10):
        """
        Args:
            model: backbone mapping a normalized RGB image to raw
                normals/albedo/roughness outputs (attribute-style dict).
            loss: DenseReg criterion; may be None for inference-only use.
            lr: Adam learning rate.
            batch_size: per-step batch size, passed to ``self.log`` so
                epoch aggregation is weighted correctly.
            max_images: cap on the number of validation samples logged
                as images per epoch.
        """
        super().__init__()
        self.model = model
        self.loss = loss
        self.lr = lr
        self.batch_size = batch_size
        self.tanh = nn.Tanh()
        # ImageNet statistics applied to the input before the backbone.
        self.norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.val_images = []  # (batch, prediction) pairs kept for image logging
        self.max_images = max_images
        self.save_hyperparameters(ignore=['model', 'loss'])

    def training_step(self, x, *_):
        """One optimization step on a (src, tgt) batch pair."""
        y = self(*x)
        loss = self.loss(x, y)
        self.log_to('train', loss)
        return dict(loss=loss.total, y=y)

    def forward(self, src, tgt):
        """Predict maps for both halves of the pair.

        A half whose batch contains a ``None`` entry (missing data) is
        skipped and returned as ``None``.
        """
        src_out, tgt_out = None, None
        if None not in src:
            src_out = self.model(self.norm(src.input))
            self.post_process_(src_out)
        if None not in tgt:
            tgt_out = self.model(self.norm(tgt.input))
            self.post_process_(tgt_out)
        return src_out, tgt_out

    def post_process_(self, o: EasyDict):
        """Map raw network outputs to their physical ranges, in place."""
        # Normals: (1) tanh, (2) append a constant +z channel, (3) renormalize
        # to unit vectors. NOTE(review): the *3 widens the tanh-limited xy
        # range before normalization (allows steeper normals) — presumably
        # intentional; confirm against training configuration.
        nxy = self.tanh(o.normals)
        nx, ny = torch.split(nxy*3, split_size_or_sections=1, dim=1)
        n = torch.cat([nx, ny, torch.ones_like(nx)], dim=1)
        o.normals = F.normalize(n, dim=1)
        # Albedo: (1) tanh, (2) map [-1,1] -> [0,1].
        a = self.tanh(o.albedo)
        o.albedo = encode_as_unit_interval(a)
        # Roughness: (1) tanh, (2) map [-1,1] -> [0,1], (3) repeat to 3 channels.
        r = self.tanh(o.roughness)
        o.roughness = encode_as_unit_interval(r.repeat(1,3,1,1))

    def validation_step(self, x, *_):
        """Compute the validation loss and stash a few batches for image logging."""
        y = self(*x)
        loss = self.loss(x, y)
        self.log_to('val', loss)
        # Keep only enough batches to cover `max_images` samples.
        if len(self.val_images) * self.batch_size < self.max_images and self.logger:
            self.val_images.append((x, y))
        return dict(loss=loss.total, y=y)

    def on_validation_epoch_end(self):
        """Log up to `max_images` stored validation samples as images.

        Samples are indexed by `cur_ind` so that all three maps of one sample
        share the same tag prefix (fixes the previous per-key index drift, and
        actually stops once `max_images` samples have been logged).
        """
        cur_ind = 0
        for x, y in self.val_images:
            for i in range(len(x[0]['normals'])):
                if cur_ind >= self.max_images:
                    break
                for key in ['normals', 'albedo', 'roughness']:
                    self.logger.experiment.add_image(f'val/{cur_ind}_{key}_pred', y[0][key][i], self.global_step)
                    self.logger.experiment.add_image(f'val/{cur_ind}_{key}_gt', x[0][key][i], self.global_step)
                cur_ind += 1
            if cur_ind >= self.max_images:
                break
        self.val_images.clear()

    def log_to(self, split, loss):
        """Log every component of `loss` under the '<split>/' prefix."""
        self.log_dict({f'{split}/{k}': v for k, v in loss.items()}, batch_size=self.batch_size)

    def log_images(self, x, y, split, max_images=5):
        """Log up to `max_images` predicted and ground-truth maps per key."""
        for key in ['normals', 'albedo', 'roughness']:
            for i in range(min(len(x[0][key]), max_images)):
                self.logger.experiment.add_image(f'{split}/{key}_pred_{i}', y[0][key][i], self.global_step)
                self.logger.experiment.add_image(f'{split}/{key}_gt_{i}', x[0][key][i], self.global_step)

    def on_test_start(self):
        """Instantiate the renderer and one metric module per entry in `metrics`.

        Assigning the metric modules via setattr registers them as submodules,
        so they follow the LightningModule's device.
        """
        self.renderer = RenderingLoss(Renderer())
        for m in Vanilla.metrics:
            if 'mse' in m:
                setattr(self, m, MeanSquaredError().to(self.device))
            elif 'ssim' in m:
                setattr(self, m, StructuralSimilarityIndexMeasure(data_range=1).to(self.device))

    def test_step(self, x, batch_idx, dl_id=0):
        """Predict maps, render a reconstruction, and update all test metrics."""
        y = self.model(self.norm(x.input))
        self.post_process_(y)
        # Image reconstruction: render from predicted maps; the reference is
        # the gamma-decoded (linear) input image.
        y.reco = self.renderer.reconstruction(y, x.input_params)
        x.reco = gamma_decode(x.input)
        for m in Vanilla.metrics:
            # First character of the metric name selects the map to compare.
            mapid, *_ = m
            k = Vanilla.maps[mapid]
            meter = getattr(self, m)
            meter(y[k], x[k].to(y[k].dtype))
            # Passing the metric object lets Lightning aggregate over the epoch.
            self.log(m, getattr(self, m), on_epoch=True)
        if self.logger:
            self.log_images(x, y, split='test')

    def predict_step(self, x, batch_idx):
        """Predict maps for one sample and save them next to the input file."""
        y = self.model(self.norm(x.input))
        self.post_process_(y)
        name, outdir = x.name[0], Path(x.path[0]).parent
        N_pred, A_pred, R_pred = y.normals[0], y.albedo[0], y.roughness[0]
        # Albedo is stored gamma-encoded; normals are remapped to [0,1].
        save_image(gamma_encode(A_pred), outdir/f'{name}_albedo.png')
        save_image(encode_as_unit_interval(N_pred), outdir/f'{name}_normals.png')
        save_image(R_pred, outdir/f'{name}_roughness.png')

    def configure_optimizers(self):
        """Adam on the backbone parameters with a small weight decay."""
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-4)
        return dict(optimizer=optimizer)
|