# vrevar
# Add application file
# 04c78c7
import torch
import torch.nn as nn
from torch import optim
from pathlib import Path
from easydict import EasyDict
import pytorch_lightning as pl
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.utils import save_image
from torchmetrics import MeanSquaredError, StructuralSimilarityIndexMeasure
#from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from . import DenseReg, RenderingLoss
from ..render import Renderer, encode_as_unit_interval, gamma_decode, gamma_encode
class Vanilla(pl.LightningModule):
    """Lightning module wrapping a network that predicts normals, albedo and roughness.

    ``model`` is called on mean/std-normalized inputs and must return an
    attribute-accessible mapping with ``normals``, ``albedo`` and ``roughness``
    entries; :meth:`post_process_` rewrites those entries in place.
    """

    # Names of the test metrics; the first character selects the map via ``maps``.
    metrics = ['I_mse','N_mse','A_mse','R_mse','I_ssim','N_ssim','A_ssim','R_ssim']
    # Metric prefix -> key of the compared tensor in the batch / output mappings.
    maps = {'I': 'reco', 'N': 'normals', 'R': 'roughness', 'A': 'albedo'}
    # Maps that are written to the experiment logger as images.
    logged_maps = ('normals', 'albedo', 'roughness')

    def __init__(self, model: nn.Module, loss: DenseReg = None, lr: float = 0, batch_size: int = 0, max_images: int = 10):
        """Store the network, the loss and the optimisation / logging settings.

        Args:
            model: network producing the raw map predictions.
            loss: callable invoked as ``loss(x, y)``; its result must expose
                ``.total`` and ``.items()``. May be ``None`` when the module
                is only used for testing or prediction.
            lr: learning rate handed to Adam.
            batch_size: batch size used for logging and for bounding the
                validation image buffer; ``0`` means "unknown".
            max_images: maximum number of validation samples logged per epoch.
        """
        super().__init__()
        self.model = model
        self.loss = loss
        self.lr = lr
        self.batch_size = batch_size
        self.tanh = nn.Tanh()
        # The constants match the commonly used ImageNet channel statistics.
        self.norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.val_images = []
        self.max_images = max_images
        self.save_hyperparameters(ignore=['model', 'loss'])

    def training_step(self, x):
        """Run the forward pass and loss on a ``(src, tgt)`` batch.

        Returns:
            dict with ``loss`` (``loss.total``) and the raw outputs ``y``.
        """
        y = self(*x)
        loss = self.loss(x, y)
        self.log_to('train', loss)
        return dict(loss=loss.total, y=y)

    def forward(self, src, tgt):
        """Predict post-processed maps for the source and target batches.

        Returns:
            ``(src_out, tgt_out)``; an entry stays ``None`` when the matching
            batch is skipped by the ``None`` membership test.
        """
        # NOTE(review): ``None not in batch`` is a membership test; on a
        # dict-like batch it inspects keys only. Confirm how the data module
        # encodes a missing domain.
        src_out, tgt_out = None, None
        if None not in src:
            src_out = self.model(self.norm(src.input))
            self.post_process_(src_out)
        if None not in tgt:
            tgt_out = self.model(self.norm(tgt.input))
            self.post_process_(tgt_out)
        return src_out, tgt_out

    def post_process_(self, o: EasyDict):
        """Convert the raw network outputs in ``o`` to final maps, in place."""
        # normals: (1) tanh, scaled by 3, (2) append z = 1, (3) normalize to unit length
        nxy = self.tanh(o.normals)
        nx, ny = torch.split(nxy*3, split_size_or_sections=1, dim=1)
        n = torch.cat([nx, ny, torch.ones_like(nx)], dim=1)
        o.normals = F.normalize(n, dim=1)

        # albedo: (1) tanh, (2) mapping [-1,1]->[0,1]
        a = self.tanh(o.albedo)
        o.albedo = encode_as_unit_interval(a)

        # roughness: (1) tanh, (2) channel repeat x3, (3) mapping [-1,1]->[0,1]
        r = self.tanh(o.roughness)
        o.roughness = encode_as_unit_interval(r.repeat(1,3,1,1))

    def validation_step(self, x, *_):
        """Compute and log the validation loss; buffer batches for image logging.

        Returns:
            dict with ``loss`` (``loss.total``) and the raw outputs ``y``.
        """
        y = self(*x)
        loss = self.loss(x, y)
        self.log_to('val', loss)

        # ``batch_size`` defaults to 0; count at least one image per buffered
        # batch so the buffer cannot grow without bound in that case.
        buffered = len(self.val_images) * max(self.batch_size, 1)
        if self.logger and buffered < self.max_images:
            self.val_images.append((x, y))
        return dict(loss=loss.total, y=y)

    def on_validation_epoch_end(self):
        """Log up to ``max_images`` buffered validation samples, then drop the buffer.

        Every sample gets one running index that is shared by its normals,
        albedo and roughness images (prediction and ground truth).
        """
        if not self.logger:
            self.val_images.clear()
            return

        writer = self.logger.experiment
        keys = self.logged_maps
        sample_idx = 0
        for x, y in self.val_images:
            if sample_idx >= self.max_images:
                break
            gt, pred = x[0], y[0]
            if pred is None:
                # forward() leaves the source output as None when it skips the batch.
                continue
            n_samples = min(len(gt[key]) for key in keys)
            for i in range(n_samples):
                if sample_idx >= self.max_images:
                    break
                for key in keys:
                    writer.add_image(f'val/{sample_idx}_{key}_pred', pred[key][i], self.global_step)
                    writer.add_image(f'val/{sample_idx}_{key}_gt', gt[key][i], self.global_step)
                sample_idx += 1
        self.val_images.clear()

    def log_to(self, split, loss):
        """Log every entry of ``loss`` under the ``{split}/`` prefix."""
        # A batch size of 0 would give the values zero weight in epoch
        # averages; pass None instead so Lightning infers it from the batch.
        self.log_dict({f'{split}/{k}': v for k, v in loss.items()}, batch_size=self.batch_size or None)

    def log_images(self, x, y, split, max_images=5):
        """Log predicted and ground-truth maps from the first element of ``x`` / ``y``.

        Args:
            x: sequence whose element 0 holds the ground-truth maps.
            y: sequence whose element 0 holds the predicted maps.
            split: tag prefix, e.g. ``'test'``.
            max_images: maximum number of samples logged per map.
        """
        if not self.logger:
            return
        writer = self.logger.experiment
        gt, pred = x[0], y[0]
        for key in self.logged_maps:
            count = min(len(gt[key]), max_images)
            for i in range(count):
                writer.add_image(f'{split}/{key}_pred_{i}', pred[key][i], self.global_step)
                writer.add_image(f'{split}/{key}_gt_{i}', gt[key][i], self.global_step)

    def on_test_start(self):
        """Create the renderer and one metric object per name in ``metrics``."""
        self.renderer = RenderingLoss(Renderer())

        for m in Vanilla.metrics:
            if 'mse' in m:
                setattr(self, m, MeanSquaredError().to(self.device))
            elif 'ssim' in m:
                # NOTE(review): data_range=1 is also applied to the normals
                # metric - confirm the value range of the normal maps.
                setattr(self, m, StructuralSimilarityIndexMeasure(data_range=1).to(self.device))

    def test_step(self, x, batch_idx, dl_id=0):
        """Predict the maps for one test batch and update / log all metrics."""
        y = self.model(self.norm(x.input))
        self.post_process_(y)

        # image reconstruction
        y.reco = self.renderer.reconstruction(y, x.input_params)
        x.reco = gamma_decode(x.input)

        for m in Vanilla.metrics:
            k = Vanilla.maps[m[0]]  # first character of the metric name selects the map
            meter = getattr(self, m)
            meter(y[k], x[k].to(y[k].dtype))
            self.log(m, meter, on_epoch=True)

        if self.logger:
            # log_images reads element 0 of its arguments, so wrap the single
            # test batch / output instead of passing the mappings directly.
            self.log_images((x,), (y,), split='test')

    def predict_step(self, x, batch_idx):
        """Predict the maps and save them as PNGs next to each input file.

        For every sample ``{name}_albedo.png``, ``{name}_normals.png`` and
        ``{name}_roughness.png`` are written to the parent directory of the
        sample's ``path``.
        """
        y = self.model(self.norm(x.input))
        self.post_process_(y)

        samples = zip(x.name, x.path, y.normals, y.albedo, y.roughness)
        for name, path, normals, albedo, roughness in samples:
            outdir = Path(path).parent
            save_image(gamma_encode(albedo), outdir/f'{name}_albedo.png')
            save_image(encode_as_unit_interval(normals), outdir/f'{name}_normals.png')
            save_image(roughness, outdir/f'{name}_roughness.png')

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters."""
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-4)
        return dict(optimizer=optimizer)