| | from pathlib import Path |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from easydict import EasyDict |
| | import torch.nn.functional as F |
| | import torchvision.transforms.functional as tf |
| |
|
| | from ..render import Renderer, Scene, generate_random_scenes, generate_specular_scenes |
| |
|
| |
|
class RenderingLoss(nn.Module):
    """Log-space L1 loss between renders, plus helpers that produce renders.

    Args:
        renderer: object exposing ``render(scene, svbrdf)``; every render made
            by this module goes through it.
        n_random_configs: count passed to ``generate_random_scenes`` each time
            :meth:`generate_scenes` is called.
        n_symmetric_configs: count passed to ``generate_specular_scenes`` each
            time :meth:`generate_scenes` is called.
        eps: offset added before the log in :meth:`forward`, so zero-valued
            inputs still give a finite log.
    """

    def __init__(self, renderer, n_random_configs=0, n_symmetric_configs=0, eps=0.1):
        super().__init__()
        self.eps = eps
        self.renderer = renderer
        self.n_random_configs = n_random_configs
        self.n_symmetric_configs = n_symmetric_configs
        # Presumably the number of scenes per sample produced by
        # generate_scenes(), assuming each generator returns exactly the
        # requested count -- TODO confirm.
        self.n_renders = n_random_configs + n_symmetric_configs

    def generate_scenes(self):
        """Return the random scenes followed by the specular scenes."""
        return (generate_random_scenes(self.n_random_configs)
                + generate_specular_scenes(self.n_symmetric_configs))

    def multiview_render(self, y, x):
        """Render ``x`` and ``y`` under the same scenes, sample by sample.

        For every sample in the batch a scene list is obtained from
        :meth:`generate_scenes`. Both SVBRDFs, given as ``(normals, albedo,
        roughness)``, are then rendered under each scene of that list.

        Beware of the order: the arguments are ``(y, x)`` but the result is
        ``(x_renders, y_renders)``.

        Returns:
            Tuple ``(x_renders, y_renders)``. Each entry is the ``torch.stack``
            over samples of that sample's ``torch.cat``-ed renders.
        """
        x_batch, y_batch = [], []

        x_maps = zip(x.normals, x.albedo, x.roughness)
        y_maps = zip(y.normals, y.albedo, y.roughness)
        for x_svbrdf, y_svbrdf in zip(x_maps, y_maps):
            x_views, y_views = [], []
            # The scene list is regenerated for each sample, and each scene is
            # used for both x and y. The x/y render calls are kept interleaved,
            # exactly as before.
            for scene in self.generate_scenes():
                x_views.append(self.renderer.render(scene, x_svbrdf))
                y_views.append(self.renderer.render(scene, y_svbrdf))
            x_batch.append(torch.cat(x_views))
            y_batch.append(torch.cat(y_views))

        return torch.stack(x_batch), torch.stack(y_batch)

    def reconstruction(self, y, theta):
        """Render each sample of ``y`` under its own scene description.

        Args:
            y: container exposing ``normals``, ``albedo`` and ``roughness``.
            theta: per-sample scene descriptions, each passed to ``Scene.load``.

        Returns:
            ``torch.cat`` of the per-sample renders.
        """
        views = []
        for *svbrdf, params in zip(y.normals, y.albedo, y.roughness, theta):
            views.append(self.renderer.render(Scene.load(params), svbrdf))
        return torch.cat(views)

    def forward(self, y, x, **kargs):
        """Return ``F.l1_loss(log(y + eps), log(x + eps), **kargs)``.

        This is defined as ``forward`` rather than by overriding ``__call__``.
        ``nn.Module.__call__`` dispatches here, so module hooks run, and
        ``.forward`` is usable directly. Extra keyword arguments, such as
        ``reduction``, are forwarded to ``F.l1_loss``.
        """
        return F.l1_loss(torch.log(y + self.eps), torch.log(x + self.eps), **kargs)
| |
|
class DenseReg(nn.Module):
    """Loss module combining SVBRDF-map regression and rendering losses.

    ``forward`` evaluates a source split and/or a target split, toggled by
    ``use_source`` / ``use_target``. For each enabled split it adds:

    * per-map regression losses: ``albedo`` and ``roughness`` use log-space
      L1 (see :meth:`log_l1`), ``normals`` uses plain L1;
    * a log-space L1 loss between multi-view renders of both sides.

    NOTE(review): :meth:`predict` names the ``y`` maps ``*_pred``, so ``y``
    appears to hold the network predictions and ``x`` the batch (inputs, names,
    ``input_params`` and reference maps). Confirm against callers.
    """

    def __init__(
        self,
        reg_weight: float,
        render_weight: float,
        pl_reg_weight: float = 0.,
        pl_render_weight: float = 0.,
        use_source: bool = True,
        use_target: bool = True,
        n_random_configs: int = 3,
        n_symmetric_configs: int = 6,
    ):
        """
        Args:
            reg_weight: weight of each source-split map loss.
            render_weight: weight of the source-split rendering loss.
            pl_reg_weight: weight of each target-split map loss.
            pl_render_weight: weight of the target-split rendering loss.
            use_source: evaluate the source split in ``forward``.
            use_target: evaluate the target split in ``forward``.
            n_random_configs: count given to ``generate_random_scenes`` for
                each sample's multi-view render.
            n_symmetric_configs: count given to ``generate_specular_scenes``
                for each sample's multi-view render.
        """
        super().__init__()

        # (map key, weight, loss fn) triples. The weight is only applied on the
        # source split; the target split uses ``pl_reg_weight`` instead.
        self.weights = [('albedo', reg_weight, self.log_l1),
                        ('roughness', reg_weight, self.log_l1),
                        ('normals', reg_weight, F.l1_loss)]

        self.reg_weight = reg_weight
        self.render_weight = render_weight
        self.pl_reg_weight = pl_reg_weight
        self.pl_render_weight = pl_render_weight
        self.use_source = use_source
        self.use_target = use_target

        self.renderer = Renderer()
        self.n_random_configs = n_random_configs
        self.n_symmetric_configs = n_symmetric_configs
        self.loss = RenderingLoss(self.renderer, n_random_configs=n_random_configs, n_symmetric_configs=n_symmetric_configs)

    def log_l1(self, x, y, **kwargs):
        """L1 loss between ``log(x + 0.01)`` and ``log(y + 0.01)``."""
        return F.l1_loss(torch.log(x + 0.01), torch.log(y + 0.01), **kwargs)

    def forward(self, x, y):
        """Compute the weighted training losses.

        Args:
            x: pair ``(x_src, x_tgt)`` of containers exposing ``albedo``,
                ``roughness``, ``normals`` (by key and attribute) and
                ``input_params``.
            y: pair ``(y_src, y_tgt)`` of containers exposing ``albedo``,
                ``roughness`` and ``normals``.

        Returns:
            EasyDict of weighted losses plus ``total``, the sum of all of them.

        Side effects:
            For each enabled split, sets ``image`` on both containers (the
            multi-view renders of that same container) and ``reco`` on the
            ``y`` container (its renders under ``x``'s ``input_params``).
        """
        loss = EasyDict()
        x_src, x_tgt = x
        y_src, y_tgt = y

        if self.use_source:
            for k, w, loss_fn in self.weights:
                loss[k] = w*loss_fn(y_src[k], x_src[k])

            # multiview_render(y, x) returns (renders of x, renders of y).
            x_src.image, y_src.image = self.loss.multiview_render(y_src, x_src)
            loss.render = self.render_weight*self.loss(y_src.image, x_src.image)

            y_src.reco = self.loss.reconstruction(y_src, x_src.input_params)

        if self.use_target:
            for k, _, loss_fn in self.weights:
                loss[f'tgt_{k}'] = self.pl_reg_weight*loss_fn(y_tgt[k], x_tgt[k])

            # Same unpacking order as the source branch. The previous code
            # stored x's renders on y_tgt and vice versa. The loss value is
            # unchanged because L1 is symmetric.
            x_tgt.image, y_tgt.image = self.loss.multiview_render(y_tgt, x_tgt)
            loss.sd_render = self.pl_render_weight*self.loss(y_tgt.image, x_tgt.image)

            y_tgt.reco = self.loss.reconstruction(y_tgt, x_tgt.input_params)

        loss.total = torch.stack(list(loss.values())).sum()
        return loss

    @torch.no_grad()
    def test(self, x, y, batch_idx, epoch, dl_id):
        """Evaluation step for a single-sample batch.

        Sets ``y.reco`` to the predicted maps rendered under
        ``x.input_params`` and returns a dummy ``total`` of 0.
        """
        assert len(x.name) == 1
        y.reco = self.loss.reconstruction(y, x.input_params)
        return EasyDict(total=0)

    @torch.no_grad()
    def predict(self, x_tgt, y_tgt, batch_idx, split, epoch):
        """Prediction step for a single-sample batch; returns a dummy loss.

        NOTE(review): the input image, sample name and predicted maps are read
        below but never used. Whatever consumed them appears to be missing
        (``Path`` and ``tf`` are imported but unused in the visible code).
        TODO confirm.
        """
        assert len(x_tgt.name) == 1

        I = x_tgt.input[0]
        name = x_tgt.name[0]

        N_pred = y_tgt.normals[0]
        A_pred = y_tgt.albedo[0]
        R_pred = y_tgt.roughness[0]

        return EasyDict(total=0)
| |
|