| import os |
| import imageio |
| import torch |
| import wandb |
| import numpy as np |
| import pytorch_lightning as pl |
| import torch.nn.functional as F |
|
|
| from module.model_2d import Encoder, Decoder, DiagonalGaussianDistribution, Encoder_GroupConv, Decoder_GroupConv, Encoder_GroupConv_LateFusion, Decoder_GroupConv_LateFusion |
| from utility.initialize import instantiate_from_config |
| from utility.triplane_renderer.renderer import get_embedder, NeRF, run_network, render_path1, to8b, img2mse, mse2psnr |
| from utility.triplane_renderer.eg3d_renderer import Renderer_TriPlane |
|
|
| class AutoencoderKL(pl.LightningModule): |
| def __init__(self, |
| ddconfig, |
| lossconfig, |
| embed_dim, |
| learning_rate=1e-3, |
| ckpt_path=None, |
| ignore_keys=[], |
| colorize_nlabels=None, |
| monitor=None, |
| decoder_ckpt=None, |
| norm=False, |
| renderer_type='nerf', |
| renderer_config=dict( |
| rgbnet_dim=18, |
| rgbnet_width=128, |
| viewpe=0, |
| feape=0 |
| ), |
| ): |
| super().__init__() |
| self.save_hyperparameters() |
| self.norm = norm |
| self.renderer_config = renderer_config |
| self.learning_rate = learning_rate |
| self.encoder = Encoder(**ddconfig) |
| self.decoder = Decoder(**ddconfig) |
| |
| self.lossconfig = lossconfig |
| assert ddconfig["double_z"] |
| self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) |
| self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) |
|
|
| self.embed_dim = embed_dim |
| if colorize_nlabels is not None: |
| assert type(colorize_nlabels)==int |
| self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) |
| if monitor is not None: |
| self.monitor = monitor |
| if ckpt_path is not None: |
| self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) |
|
|
| self.decoder_ckpt = decoder_ckpt |
| self.renderer_type = renderer_type |
| |
| assert self.renderer_type in ['nerf', 'eg3d'] |
| if self.renderer_type == 'nerf': |
| self.triplane_decoder, self.triplane_render_kwargs = self.create_nerf(decoder_ckpt) |
| elif self.renderer_type == 'eg3d': |
| self.triplane_decoder, self.triplane_render_kwargs = self.create_eg3d_decoder(decoder_ckpt) |
| else: |
| raise NotImplementedError |
|
|
| self.psum = torch.zeros([1]) |
| self.psum_sq = torch.zeros([1]) |
| self.psum_min = torch.zeros([1]) |
| self.psum_max = torch.zeros([1]) |
| self.count = 0 |
| self.len_dset = 0 |
| self.latent_list = [] |
|
|
| def init_from_ckpt(self, path, ignore_keys=list()): |
| sd = torch.load(path, map_location="cpu")["state_dict"] |
| keys = list(sd.keys()) |
| for k in keys: |
| for ik in ignore_keys: |
| if k.startswith(ik): |
| print("Deleting key {} from state_dict.".format(k)) |
| del sd[k] |
| self.load_state_dict(sd, strict=False) |
| print(f"Restored from {path}") |
|
|
| def encode(self, x, rollout=False): |
| h = self.encoder(x) |
| moments = self.quant_conv(h) |
| posterior = DiagonalGaussianDistribution(moments) |
| return posterior |
|
|
| def decode(self, z, unrollout=False): |
| z = self.post_quant_conv(z) |
| dec = self.decoder(z) |
| return dec |
|
|
| def forward(self, input, sample_posterior=True): |
| posterior = self.encode(input) |
| if sample_posterior: |
| z = posterior.sample() |
| else: |
| z = posterior.mode() |
| dec = self.decode(z) |
| return dec, posterior |
|
|
| def unrollout(self, *args, **kwargs): |
| pass |
|
|
    def loss(self, inputs, reconstructions, posteriors, prefix, batch=None):
        """Compute the VAE objective plus optional triplane regularizers.

        Args:
            inputs: target triplanes.
            reconstructions: decoder output, same shape as *inputs*.
            posteriors: DiagonalGaussianDistribution from the encoder.
            prefix: log-key prefix, e.g. 'train/' or 'val/'.
            batch: full batch dict; only needed when ``render_weight`` > 0.

        Returns:
            (total_loss, dict of scalars for logging).
        """
        reconstructions = reconstructions.contiguous()
        # Sum-reduced L1 reconstruction error, averaged over the batch dim only.
        rec_loss = torch.abs(inputs.contiguous() - reconstructions)
        rec_loss = torch.sum(rec_loss) / rec_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        loss = self.lossconfig.rec_weight * rec_loss + self.lossconfig.kl_weight * kl_loss

        ret_dict = {
            prefix+'mean_rec_loss': torch.abs(inputs.contiguous() - reconstructions.contiguous()).mean().detach(),
            prefix+'rec_loss': rec_loss,
            prefix+'kl_loss': kl_loss,
            prefix+'loss': loss,
            prefix+'mean': posteriors.mean.mean(),
            prefix+'logvar': posteriors.logvar.mean(),
        }

        # Optional regularizer weights; all default to 0 (disabled).
        render_weight = self.lossconfig.get("render_weight", 0)
        tv_weight = self.lossconfig.get("tv_weight", 0)
        l1_weight = self.lossconfig.get("l1_weight", 0)
        latent_tv_weight = self.lossconfig.get("latent_tv_weight", 0)
        latent_l1_weight = self.lossconfig.get("latent_l1_weight", 0)

        # NOTE(review): in the base class ``unrollout`` is a placeholder; the
        # branches below assume a roll-out subclass when enabled.
        triplane_rec = self.unrollout(reconstructions)
        if render_weight > 0 and batch is not None:
            # Supervise renderings at randomly sampled pixels; the factor 256
            # rescales the per-pixel MSE (presumably to image resolution —
            # TODO confirm).
            rgb_rendered, target = self.render_triplane_eg3d_decoder_sample_pixel(triplane_rec, batch['batch_rays'], batch['img'])
            render_loss = ((rgb_rendered - target) ** 2).sum() / rgb_rendered.shape[0] * 256
            loss += render_weight * render_loss
            ret_dict[prefix + 'render_loss'] = render_loss
        if tv_weight > 0:
            # Total variation along H and W of the reconstructed triplane.
            tvloss_y = torch.abs(triplane_rec[:, :, :-1] - triplane_rec[:, :, 1:]).sum() / triplane_rec.shape[0]
            tvloss_x = torch.abs(triplane_rec[:, :, :, :-1] - triplane_rec[:, :, :, 1:]).sum() / triplane_rec.shape[0]
            tvloss = tvloss_y + tvloss_x
            loss += tv_weight * tvloss
            ret_dict[prefix + 'tv_loss'] = tvloss
        if l1_weight > 0:
            # NOTE(review): named "l1" but computes a sum of squares (L2^2).
            l1 = (triplane_rec ** 2).sum() / triplane_rec.shape[0]
            loss += l1_weight * l1
            ret_dict[prefix + 'l1_loss'] = l1
        if latent_tv_weight > 0:
            # Same TV penalty applied to the posterior mean (the latent).
            latent = posteriors.mean
            latent_tv_y = torch.abs(latent[:, :, :-1] - latent[:, :, 1:]).sum() / latent.shape[0]
            latent_tv_x = torch.abs(latent[:, :, :, :-1] - latent[:, :, :, 1:]).sum() / latent.shape[0]
            latent_tv_loss = latent_tv_y + latent_tv_x
            loss += latent_tv_loss * latent_tv_weight
            ret_dict[prefix + 'latent_tv_loss'] = latent_tv_loss
            ret_dict[prefix + 'latent_max'] = latent.max()
            ret_dict[prefix + 'latent_min'] = latent.min()
        if latent_l1_weight > 0:
            latent = posteriors.mean
            # NOTE(review): also squared values despite the "l1" name.
            latent_l1_loss = (latent ** 2).sum() / latent.shape[0]
            loss += latent_l1_loss * latent_l1_weight
            ret_dict[prefix + 'latent_l1_loss'] = latent_l1_loss

        return loss, ret_dict
|
|
| def training_step(self, batch, batch_idx): |
| |
| inputs = batch['triplane'] |
| reconstructions, posterior = self(inputs) |
|
|
| |
| |
| aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/') |
| |
| self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) |
| return aeloss |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| def validation_step(self, batch, batch_idx): |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
|
|
| inputs = batch['triplane'] |
| reconstructions, posterior = self(inputs, sample_posterior=False) |
| aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/') |
| self.log_dict(log_dict_ae) |
|
|
| assert not self.norm |
| psnr_list = [] |
| psnr_input_list = [] |
| psnr_rec_list = [] |
| batch_size = inputs.shape[0] |
| for b in range(batch_size): |
| if self.renderer_type == 'nerf': |
| rgb_input, cur_psnr_list_input = self.render_triplane( |
| batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
| batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
| ) |
| rgb, cur_psnr_list = self.render_triplane( |
| reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
| batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
| ) |
| elif self.renderer_type == 'eg3d': |
| rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( |
| batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], |
| ) |
| rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( |
| reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], |
| ) |
| else: |
| raise NotImplementedError |
|
|
| cur_psnr_list_rec = [] |
| for i in range(rgb.shape[0]): |
| cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) |
|
|
| rgb_input = to8b(rgb_input.detach().cpu().numpy()) |
| rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) |
| rgb = to8b(rgb.detach().cpu().numpy()) |
| |
| if b % 4 == 0 and batch_idx < 10: |
| rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) |
| self.logger.experiment.log({ |
| "val/vis": [wandb.Image(rgb_all)] |
| }) |
|
|
| psnr_list += cur_psnr_list |
| psnr_input_list += cur_psnr_list_input |
| psnr_rec_list += cur_psnr_list_rec |
|
|
| self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) |
| self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) |
| self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) |
|
|
| return self.log_dict |
|
|
| def create_eg3d_decoder(self, decoder_ckpt): |
| triplane_decoder = Renderer_TriPlane(**self.renderer_config) |
| if decoder_ckpt is not None: |
| pretrain_pth = torch.load(decoder_ckpt, map_location='cpu') |
| pretrain_pth = { |
| '.'.join(k.split('.')[1:]): v for k, v in pretrain_pth.items() |
| } |
| triplane_decoder.load_state_dict(pretrain_pth) |
| render_kwargs = { |
| 'depth_resolution': 128, |
| 'disparity_space_sampling': False, |
| 'box_warp': 2.4, |
| 'depth_resolution_importance': 128, |
| 'clamp_mode': 'softplus', |
| 'white_back': True, |
| 'det': True |
| } |
| return triplane_decoder, render_kwargs |
|
|
| def render_triplane_eg3d_decoder(self, triplane, batch_rays, target): |
| ray_o = batch_rays[:, 0] |
| ray_d = batch_rays[:, 1] |
| psnr_list = [] |
| rec_img_list = [] |
| res = triplane.shape[-2] |
| for i in range(ray_o.shape[0]): |
| with torch.no_grad(): |
| render_out = self.triplane_decoder(triplane.reshape(1, 3, -1, res, res), |
| ray_o[i:i+1], ray_d[i:i+1], self.triplane_render_kwargs, whole_img=True, tvloss=False) |
| rec_img = render_out['rgb_marched'].permute(0, 2, 3, 1) |
| psnr = mse2psnr(img2mse(rec_img[0], target[i])) |
| psnr_list.append(psnr) |
| rec_img_list.append(rec_img) |
| return torch.cat(rec_img_list, 0), psnr_list |
|
|
| def render_triplane_eg3d_decoder_sample_pixel(self, triplane, batch_rays, target, sample_num=1024): |
| assert batch_rays.shape[1] == 1 |
| sel = torch.randint(batch_rays.shape[-2], [sample_num]) |
| ray_o = batch_rays[:, 0, 0, sel] |
| ray_d = batch_rays[:, 0, 1, sel] |
| res = triplane.shape[-2] |
| render_out = self.triplane_decoder(triplane.reshape(triplane.shape[0], 3, -1, res, res), |
| ray_o, ray_d, self.triplane_render_kwargs, whole_img=False, tvloss=False) |
| rec_img = render_out['rgb_marched'] |
| target = target.reshape(triplane.shape[0], -1, 3)[:, sel, :] |
| return rec_img, target |
|
|
| def create_nerf(self, decoder_ckpt): |
| |
|
|
| multires = 10 |
| netchunk = 1024*64 |
| i_embed = 0 |
| perturb = 0 |
| raw_noise_std = 0 |
|
|
| triplanechannel=18 |
| triplanesize=256 |
| chunk=4096 |
| num_instance=1 |
| batch_size=1 |
| use_viewdirs = True |
| white_bkgd = False |
| lrate_decay = 6 |
| netdepth=1 |
| netwidth=64 |
| N_samples = 512 |
| N_importance = 0 |
| N_rand = 8192 |
| multires_views=10 |
| precrop_iters = 0 |
| precrop_frac = 0.5 |
| i_weights=3000 |
|
|
| embed_fn, input_ch = get_embedder(multires, i_embed) |
| embeddirs_fn, input_ch_views = get_embedder(multires_views, i_embed) |
| output_ch = 4 |
| skips = [4] |
| model = NeRF(D=netdepth, W=netwidth, |
| input_ch=triplanechannel, size=triplanesize,output_ch=output_ch, skips=skips, |
| input_ch_views=input_ch_views, use_viewdirs=use_viewdirs, num_instance=num_instance) |
| |
| network_query_fn = lambda inputs, viewdirs, label,network_fn : \ |
| run_network(inputs, viewdirs, network_fn, |
| embed_fn=embed_fn, |
| embeddirs_fn=embeddirs_fn,label=label, |
| netchunk=netchunk) |
|
|
| ckpt = torch.load(decoder_ckpt) |
| model.load_state_dict(ckpt['network_fn_state_dict']) |
|
|
| render_kwargs_test = { |
| 'network_query_fn' : network_query_fn, |
| 'perturb' : perturb, |
| 'N_samples' : N_samples, |
| |
| 'use_viewdirs' : use_viewdirs, |
| 'white_bkgd' : white_bkgd, |
| 'raw_noise_std' : raw_noise_std, |
| } |
| render_kwargs_test['ndc'] = False |
| render_kwargs_test['lindisp'] = False |
| render_kwargs_test['perturb'] = False |
| render_kwargs_test['raw_noise_std'] = 0. |
|
|
| return model, render_kwargs_test |
|
|
    def render_triplane(self, triplane, batch_rays, target, near, far, chunk=4096):
        """Render views with the NeRF-style decoder; returns (rgb, psnr_list).

        Copies *triplane* into the decoder's internal plane buffer, then
        renders every ray batch against the *target* ground-truth images.
        """
        # In-place copy: the NeRF decoder holds the triplane as a module tensor.
        self.triplane_decoder.tri_planes.copy_(triplane.detach())
        self.triplane_render_kwargs['network_fn'] = self.triplane_decoder
        with torch.no_grad():
            # label selects the (single) instance slot inside the decoder.
            rgb, _, _, psnr_list = \
                render_path1(batch_rays, chunk, self.triplane_render_kwargs, gt_imgs=target,
                             near=near, far=far, label=torch.Tensor([0]).long().to(triplane.device))
        return rgb, psnr_list
|
|
| def to_rgb(self, plane): |
| x = plane.float() |
| if not hasattr(self, "colorize"): |
| self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = torch.nn.functional.conv2d(x, weight=self.colorize) |
| x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
| return x |
|
|
| def to_rgb_triplane(self, plane): |
| x = plane.float() |
| if not hasattr(self, "colorize_triplane"): |
| self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) |
| x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
| return x |
|
|
| def test_step(self, batch, batch_idx): |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| inputs = batch['triplane'] |
| reconstructions, posterior = self(inputs, sample_posterior=False) |
| aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) |
| self.log_dict(log_dict_ae) |
|
|
| batch_size = inputs.shape[0] |
| psnr_list = [] |
| psnr_input_list = [] |
| psnr_rec_list = [] |
|
|
| z = posterior.mode() |
| colorize_z = self.to_rgb(z)[0] |
| colorize_triplane_input = self.to_rgb_triplane(inputs)[0] |
| colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] |
| |
| |
| |
| |
| if batch_idx < 10: |
| imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) |
| |
| |
| |
|
|
| np_z = z.detach().cpu().numpy() |
| |
| |
|
|
| self.latent_list.append(np_z) |
|
|
| if self.psum.device != z.device: |
| self.psum = self.psum.to(z.device) |
| self.psum_sq = self.psum_sq.to(z.device) |
| self.psum_min = self.psum_min.to(z.device) |
| self.psum_max = self.psum_max.to(z.device) |
| self.psum += z.sum() |
| self.psum_sq += (z ** 2).sum() |
| self.psum_min += z.reshape(-1).min(-1)[0] |
| self.psum_max += z.reshape(-1).max(-1)[0] |
| assert len(z.shape) == 4 |
| self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] |
| self.len_dset += 1 |
|
|
| if self.norm: |
| assert NotImplementedError |
| else: |
| reconstructions_unnormalize = reconstructions |
|
|
| for b in range(batch_size): |
| if self.renderer_type == 'nerf': |
| rgb_input, cur_psnr_list_input = self.render_triplane( |
| batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
| batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
| ) |
| rgb, cur_psnr_list = self.render_triplane( |
| reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
| batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
| ) |
| elif self.renderer_type == 'eg3d': |
| rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( |
| batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], |
| ) |
| rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( |
| reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], |
| ) |
| else: |
| raise NotImplementedError |
|
|
| cur_psnr_list_rec = [] |
| for i in range(rgb.shape[0]): |
| cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) |
|
|
| rgb_input = to8b(rgb_input.detach().cpu().numpy()) |
| rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) |
| rgb = to8b(rgb.detach().cpu().numpy()) |
| |
| if batch_idx < 10: |
| imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) |
|
|
| psnr_list += cur_psnr_list |
| psnr_input_list += cur_psnr_list_input |
| psnr_rec_list += cur_psnr_list_rec |
|
|
| self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) |
| self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) |
| self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) |
|
|
| def configure_optimizers(self): |
| lr = self.learning_rate |
| opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ |
| list(self.decoder.parameters())+ |
| list(self.quant_conv.parameters())+ |
| list(self.post_quant_conv.parameters()), |
| lr=lr, betas=(0.5, 0.9)) |
| |
| |
| |
| return opt_ae |
|
|
| def on_test_epoch_end(self): |
| mean = self.psum / self.count |
| mean_min = self.psum_min / self.len_dset |
| mean_max = self.psum_max / self.len_dset |
| var = (self.psum_sq / self.count) - (mean ** 2) |
| std = torch.sqrt(var) |
|
|
| print("mean min: {}".format(mean_min)) |
| print("mean max: {}".format(mean_max)) |
| print("mean: {}".format(mean)) |
| print("std: {}".format(std)) |
|
|
| latent = np.concatenate(self.latent_list) |
| q75, q25 = np.percentile(latent.reshape(-1), [75 ,25]) |
| median = np.median(latent.reshape(-1)) |
| iqr = q75 - q25 |
| norm_iqr = iqr * 0.7413 |
| print("Norm IQR: {}".format(norm_iqr)) |
| print("Inverse Norm IQR: {}".format(1/norm_iqr)) |
| print("Median: {}".format(median)) |
|
|
|
|
| class AutoencoderKLRollOut(AutoencoderKL): |
| def __init__(self, *args, **kwargs): |
| super().__init__(*args, **kwargs) |
| self.psum = torch.zeros([1]) |
| self.psum_sq = torch.zeros([1]) |
| self.psum_min = torch.zeros([1]) |
| self.psum_max = torch.zeros([1]) |
| self.count = 0 |
| self.len_dset = 0 |
|
|
| def rollout(self, triplane): |
| res = triplane.shape[-1] |
| ch = triplane.shape[1] |
| triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) |
| return triplane |
|
|
| def unrollout(self, triplane): |
| res = triplane.shape[-2] |
| ch = 3 * triplane.shape[1] |
| triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) |
| return triplane |
|
|
| def encode(self, x, rollout=False): |
| if rollout: |
| x = self.rollout(x) |
| h = self.encoder(x) |
| moments = self.quant_conv(h) |
| posterior = DiagonalGaussianDistribution(moments) |
| return posterior |
|
|
| def decode(self, z, unrollout=False): |
| z = self.post_quant_conv(z) |
| dec = self.decoder(z) |
| if unrollout: |
| dec = self.unrollout(dec) |
| return dec |
|
|
| def forward(self, input, sample_posterior=True): |
| posterior = self.encode(input) |
| if sample_posterior: |
| z = posterior.sample() |
| else: |
| z = posterior.mode() |
| dec = self.decode(z) |
| return dec, posterior |
|
|
| def training_step(self, batch, batch_idx): |
| inputs = self.rollout(batch['triplane']) |
| reconstructions, posterior = self(inputs) |
| aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/') |
| self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) |
| return aeloss |
|
|
| def validation_step(self, batch, batch_idx): |
| inputs = self.rollout(batch['triplane']) |
| reconstructions, posterior = self(inputs, sample_posterior=False) |
| aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/') |
| self.log_dict(log_dict_ae) |
|
|
| assert not self.norm |
| reconstructions = self.unrollout(reconstructions) |
| psnr_list = [] |
| psnr_input_list = [] |
| psnr_rec_list = [] |
| batch_size = inputs.shape[0] |
| for b in range(batch_size): |
| if self.renderer_type == 'nerf': |
| rgb_input, cur_psnr_list_input = self.render_triplane( |
| batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
| batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
| ) |
| rgb, cur_psnr_list = self.render_triplane( |
| reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
| batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
| ) |
| elif self.renderer_type == 'eg3d': |
| rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( |
| batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], |
| ) |
| rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( |
| reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], |
| ) |
| else: |
| raise NotImplementedError |
|
|
| cur_psnr_list_rec = [] |
| for i in range(rgb.shape[0]): |
| cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) |
|
|
| rgb_input = to8b(rgb_input.detach().cpu().numpy()) |
| rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) |
| rgb = to8b(rgb.detach().cpu().numpy()) |
| |
| if b % 4 == 0 and batch_idx < 10: |
| rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) |
| self.logger.experiment.log({ |
| "val/vis": [wandb.Image(rgb_all)] |
| }) |
|
|
| psnr_list += cur_psnr_list |
| psnr_input_list += cur_psnr_list_input |
| psnr_rec_list += cur_psnr_list_rec |
|
|
| self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) |
| self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) |
| self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) |
|
|
| return self.log_dict |
|
|
| def to_rgb(self, plane): |
| x = plane.float() |
| if not hasattr(self, "colorize"): |
| self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = torch.nn.functional.conv2d(x, weight=self.colorize) |
| x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
| return x |
|
|
| def to_rgb_triplane(self, plane): |
| x = plane.float() |
| if not hasattr(self, "colorize_triplane"): |
| self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) |
| x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
| return x |
|
|
    def test_step(self, batch, batch_idx):
        """Test on rolled-out triplanes and accumulate latent statistics.

        Unlike the base class, the rendering/PSNR evaluation is absent here;
        only the loss logging and the running latent statistics (consumed by
        ``on_test_epoch_end``) are active. Several intermediates below are
        computed but never consumed — apparently leftovers of removed
        visualization code.
        """
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(inputs, sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/')
        self.log_dict(log_dict_ae)

        # Unused below (no rendering happens in this override).
        batch_size = inputs.shape[0]
        psnr_list = []
        psnr_input_list = []
        psnr_rec_list = []

        # Colorized visualizations; currently computed but never written out.
        z = posterior.mode()
        colorize_z = self.to_rgb(z)[0]
        colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
        colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]

        reconstructions = self.unrollout(reconstructions)

        # Running latent statistics (lazily moved to the latent's device).
        if self.psum.device != z.device:
            self.psum = self.psum.to(z.device)
            self.psum_sq = self.psum_sq.to(z.device)
            self.psum_min = self.psum_min.to(z.device)
            self.psum_max = self.psum_max.to(z.device)
        self.psum += z.sum()
        self.psum_sq += (z ** 2).sum()
        self.psum_min += z.reshape(-1).min(-1)[0]
        self.psum_max += z.reshape(-1).max(-1)[0]
        assert len(z.shape) == 4
        self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
        self.len_dset += 1

        # Hard-coded per-channel dataset statistics for the 18-channel
        # triplanes; only used when self.norm is enabled.
        mean = torch.Tensor([
            -1.8449, -1.8242, 0.9667, -1.0187, 1.0647, -0.5422, -1.8632, -1.8435,
            0.9314, -1.0261, 1.0356, -0.5484, -1.8543, -1.8348, 0.9109, -1.0169,
            1.0160, -0.5467
        ]).reshape(1, 18, 1, 1).to(inputs.device)
        std = torch.Tensor([
            1.7593, 1.6127, 2.7132, 1.5500, 2.7893, 0.7707, 2.1114, 1.9198, 2.6586,
            1.8021, 2.5473, 1.0305, 1.7042, 1.7507, 2.4270, 1.4365, 2.2511, 0.8792
        ]).reshape(1, 18, 1, 1).to(inputs.device)

        # NOTE(review): reconstructions_unnormalize is not consumed afterwards
        # in this override (rendering was removed) — confirm intended.
        if self.norm:
            reconstructions_unnormalize = reconstructions * std + mean
        else:
            reconstructions_unnormalize = reconstructions
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| def on_test_epoch_end(self): |
| mean = self.psum / self.count |
| mean_min = self.psum_min / self.len_dset |
| mean_max = self.psum_max / self.len_dset |
| var = (self.psum_sq / self.count) - (mean ** 2) |
| std = torch.sqrt(var) |
|
|
| print("mean min: {}".format(mean_min)) |
| print("mean max: {}".format(mean_max)) |
| print("mean: {}".format(mean)) |
| print("std: {}".format(std)) |
|
|
|
|
| class AutoencoderKLRollOut3DAware(AutoencoderKL): |
| def __init__(self, *args, **kwargs): |
| try: |
| ckpt_path = kwargs['ckpt_path'] |
| kwargs['ckpt_path'] = None |
| except: |
| ckpt_path = None |
| |
| super().__init__(*args, **kwargs) |
| self.psum = torch.zeros([1]) |
| self.psum_sq = torch.zeros([1]) |
| self.psum_min = torch.zeros([1]) |
| self.psum_max = torch.zeros([1]) |
| self.count = 0 |
| self.len_dset = 0 |
|
|
| ddconfig = kwargs['ddconfig'] |
| ddconfig['z_channels'] *= 3 |
| del self.decoder |
| del self.post_quant_conv |
| self.decoder = Decoder(**ddconfig) |
| self.post_quant_conv = torch.nn.Conv2d(kwargs['embed_dim'] * 3, ddconfig["z_channels"], 1) |
|
|
| if ckpt_path is not None: |
| self.init_from_ckpt(ckpt_path) |
|
|
| def rollout(self, triplane): |
| res = triplane.shape[-1] |
| ch = triplane.shape[1] |
| triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) |
| return triplane |
|
|
    def to3daware(self, triplane):
        """Append cross-plane context channels to a rolled-out triplane.

        Input is (B, C, res, 3*res) with the three planes side by side. For
        each plane, the other two planes are max-pooled along one spatial
        axis, broadcast back to (res, res), and concatenated on the channel
        dimension, yielding (B, 3*C, res, 3*res).
        """
        res = triplane.shape[-2]
        plane1 = triplane[..., :res]
        plane2 = triplane[..., res:2*res]
        plane3 = triplane[..., 2*res:3*res]

        # Pool a source plane down to one row / one column...
        x_mp = torch.nn.MaxPool2d((res, 1))
        y_mp = torch.nn.MaxPool2d((1, res))
        # ...then tile it back to res x res and transpose so the shared axis
        # lines up with the destination plane.
        x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
        y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)

        # NOTE(review): the flips below presumably match the axis-orientation
        # conventions of the triplane layout — verify against the renderer's
        # plane-sampling code.
        plane21 = x_mp_rep(plane2)
        plane31 = torch.flip(y_mp_rep(plane3), (3,))
        new_plane1 = torch.cat([plane1, plane21, plane31], 1)

        plane12 = y_mp_rep(plane1)
        plane32 = x_mp_rep(plane3)
        new_plane2 = torch.cat([plane2, plane12, plane32], 1)

        plane13 = torch.flip(x_mp_rep(plane1), (2,))
        plane23 = y_mp_rep(plane2)
        new_plane3 = torch.cat([plane3, plane13, plane23], 1)

        new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
        return new_plane
|
|
| def unrollout(self, triplane): |
| res = triplane.shape[-2] |
| ch = 3 * triplane.shape[1] |
| triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) |
| return triplane |
|
|
| def encode(self, x, rollout=False): |
| if rollout: |
| x = self.to3daware(self.rollout(x)) |
| h = self.encoder(x) |
| moments = self.quant_conv(h) |
| posterior = DiagonalGaussianDistribution(moments) |
| return posterior |
|
|
| def decode(self, z, unrollout=False): |
| z = self.to3daware(z) |
| z = self.post_quant_conv(z) |
| dec = self.decoder(z) |
| if unrollout: |
| dec = self.unrollout(dec) |
| return dec |
|
|
| def forward(self, input, sample_posterior=True): |
| posterior = self.encode(input) |
| if sample_posterior: |
| z = posterior.sample() |
| else: |
| z = posterior.mode() |
| dec = self.decode(z) |
| return dec, posterior |
|
|
| def training_step(self, batch, batch_idx): |
| inputs = self.rollout(batch['triplane']) |
| reconstructions, posterior = self(self.to3daware(inputs)) |
| aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch) |
| self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) |
| return aeloss |
|
|
| def validation_step(self, batch, batch_idx): |
| inputs = self.rollout(batch['triplane']) |
| reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False) |
| aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None) |
| self.log_dict(log_dict_ae) |
|
|
| assert not self.norm |
| reconstructions = self.unrollout(reconstructions) |
| psnr_list = [] |
| psnr_input_list = [] |
| psnr_rec_list = [] |
| batch_size = inputs.shape[0] |
| for b in range(batch_size): |
| if self.renderer_type == 'nerf': |
| rgb_input, cur_psnr_list_input = self.render_triplane( |
| batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
| batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
| ) |
| rgb, cur_psnr_list = self.render_triplane( |
| reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
| batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
| ) |
| elif self.renderer_type == 'eg3d': |
| rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( |
| batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], |
| ) |
| rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( |
| reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], |
| ) |
| else: |
| raise NotImplementedError |
|
|
| cur_psnr_list_rec = [] |
| for i in range(rgb.shape[0]): |
| cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) |
|
|
| rgb_input = to8b(rgb_input.detach().cpu().numpy()) |
| rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) |
| rgb = to8b(rgb.detach().cpu().numpy()) |
| |
| if b % 4 == 0 and batch_idx < 10: |
| rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) |
| self.logger.experiment.log({ |
| "val/vis": [wandb.Image(rgb_all)] |
| }) |
|
|
| psnr_list += cur_psnr_list |
| psnr_input_list += cur_psnr_list_input |
| psnr_rec_list += cur_psnr_list_rec |
|
|
| self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) |
| self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) |
| self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) |
|
|
| return self.log_dict |
|
|
| def to_rgb(self, plane): |
| x = plane.float() |
| if not hasattr(self, "colorize"): |
| self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = torch.nn.functional.conv2d(x, weight=self.colorize) |
| x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
| return x |
|
|
| def to_rgb_triplane(self, plane): |
| x = plane.float() |
| if not hasattr(self, "colorize_triplane"): |
| self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) |
| x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
| return x |
| |
| def to_rgb_3daware(self, plane): |
| x = plane.float() |
| if not hasattr(self, "colorize_3daware"): |
| self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) |
| x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
| return x |
|
|
| def test_step(self, batch, batch_idx): |
| inputs = self.rollout(batch['triplane']) |
| reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False) |
| aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) |
| self.log_dict(log_dict_ae) |
|
|
| batch_size = inputs.shape[0] |
| psnr_list = [] |
| psnr_input_list = [] |
| psnr_rec_list = [] |
|
|
| z = posterior.mode() |
| colorize_z = self.to_rgb(z)[0] |
| colorize_triplane_input = self.to_rgb_triplane(inputs)[0] |
| colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] |
| colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0] |
| res = inputs.shape[1] |
| colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0] |
| colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0] |
| if batch_idx < 10: |
| imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2) |
|
|
| reconstructions = self.unrollout(reconstructions) |
|
|
| if self.psum.device != z.device: |
| self.psum = self.psum.to(z.device) |
| self.psum_sq = self.psum_sq.to(z.device) |
| self.psum_min = self.psum_min.to(z.device) |
| self.psum_max = self.psum_max.to(z.device) |
| self.psum += z.sum() |
| self.psum_sq += (z ** 2).sum() |
| self.psum_min += z.reshape(-1).min(-1)[0] |
| self.psum_max += z.reshape(-1).max(-1)[0] |
| assert len(z.shape) == 4 |
| self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] |
| self.len_dset += 1 |
|
|
| if self.norm: |
| assert NotImplementedError |
| else: |
| reconstructions_unnormalize = reconstructions |
|
|
| for b in range(batch_size): |
| if self.renderer_type == 'nerf': |
| rgb_input, cur_psnr_list_input = self.render_triplane( |
| batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
| batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
| ) |
| rgb, cur_psnr_list = self.render_triplane( |
| reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
| batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
| ) |
| elif self.renderer_type == 'eg3d': |
| rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( |
| batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], |
| ) |
| rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( |
| reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], |
| ) |
| else: |
| raise NotImplementedError |
|
|
| cur_psnr_list_rec = [] |
| for i in range(rgb.shape[0]): |
| cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) |
|
|
| rgb_input = to8b(rgb_input.detach().cpu().numpy()) |
| rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) |
| rgb = to8b(rgb.detach().cpu().numpy()) |
| |
| if batch_idx < 10: |
| imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) |
|
|
| psnr_list += cur_psnr_list |
| psnr_input_list += cur_psnr_list_input |
| psnr_rec_list += cur_psnr_list_rec |
|
|
| self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) |
| self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) |
| self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) |
|
|
| def on_test_epoch_end(self): |
| mean = self.psum / self.count |
| mean_min = self.psum_min / self.len_dset |
| mean_max = self.psum_max / self.len_dset |
| var = (self.psum_sq / self.count) - (mean ** 2) |
| std = torch.sqrt(var) |
|
|
| print("mean min: {}".format(mean_min)) |
| print("mean max: {}".format(mean_max)) |
| print("mean: {}".format(mean)) |
| print("std: {}".format(std)) |
|
|
|
|
class AutoencoderKLRollOut3DAwareOnlyInput(AutoencoderKL):
    """Rolled-out triplane VAE where only the *encoder input* is 3D-aware.

    Callers feed ``to3daware(rollout(x))`` (channel-tripled via max-pooled
    cross-plane features) into the encoder, while ``decode`` reconstructs the
    plain rolled-out triplane without any cross-plane augmentation.
    Also accumulates running latent statistics during ``test_step`` and
    reports them in ``on_test_epoch_end``.
    """

    def __init__(self, *args, **kwargs):
        # Strip ckpt_path before calling the parent so the checkpoint is only
        # loaded after this subclass finished its own setup (see bottom).
        try:
            ckpt_path = kwargs['ckpt_path']
            kwargs['ckpt_path'] = None
        except:  # NOTE(review): bare except — KeyError is the intended case
            ckpt_path = None

        super().__init__(*args, **kwargs)
        # Running latent statistics filled in test_step. Plain tensors (not
        # registered buffers), moved to the model device lazily in test_step.
        self.psum = torch.zeros([1])
        self.psum_sq = torch.zeros([1])
        self.psum_min = torch.zeros([1])
        self.psum_max = torch.zeros([1])
        self.count = 0     # total number of latent elements seen
        self.len_dset = 0  # number of test batches seen

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

    def rollout(self, triplane):
        """Reshape a stacked triplane (B, 3*C, H, H) into rolled-out form (B, C, H, 3*H)."""
        res = triplane.shape[-1]
        ch = triplane.shape[1]
        triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res)
        return triplane

    def to3daware(self, triplane):
        """Augment each plane of a rolled-out triplane with max-pooled features
        of the other two planes, concatenated along channels (tripling C)."""
        res = triplane.shape[-2]
        plane1 = triplane[..., :res]
        plane2 = triplane[..., res:2*res]
        plane3 = triplane[..., 2*res:3*res]

        # Pool away one spatial axis, broadcast back to (res, res), then
        # transpose so the surviving axis lines up with the receiving plane.
        x_mp = torch.nn.MaxPool2d((res, 1))
        y_mp = torch.nn.MaxPool2d((1, res))
        x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
        y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)

        plane21 = x_mp_rep(plane2)
        # flips reorient the pooled features to match the receiving plane's axes
        plane31 = torch.flip(y_mp_rep(plane3), (3,))
        new_plane1 = torch.cat([plane1, plane21, plane31], 1)

        plane12 = y_mp_rep(plane1)
        plane32 = x_mp_rep(plane3)
        new_plane2 = torch.cat([plane2, plane12, plane32], 1)

        plane13 = torch.flip(x_mp_rep(plane1), (2,))
        plane23 = y_mp_rep(plane2)
        new_plane3 = torch.cat([plane3, plane13, plane23], 1)

        new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
        return new_plane

    def unrollout(self, triplane):
        """Inverse of rollout: (B, C, H, 3*H) -> (B, 3*C, H, H)."""
        res = triplane.shape[-2]
        ch = 3 * triplane.shape[1]
        triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res)
        return triplane

    def encode(self, x, rollout=False):
        """Encode into a diagonal-Gaussian posterior.

        With rollout=True, x is a stacked triplane and is rolled out and made
        3D-aware first; otherwise x is assumed already rolled out + 3D-aware.
        """
        if rollout:
            x = self.to3daware(self.rollout(x))
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z, unrollout=False):
        """Decode a latent to a rolled-out triplane ("only input" variant: no
        to3daware on the latent); optionally re-stack via unrollout."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        if unrollout:
            dec = self.unrollout(dec)
        return dec

    def forward(self, input, sample_posterior=True):
        """Full VAE pass on an already 3D-aware rolled-out triplane."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def training_step(self, batch, batch_idx):
        """One optimization step; the loss compares against the plain rolled-out input."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(self.to3daware(inputs))
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/')
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
        return aeloss

    def validation_step(self, batch, batch_idx):
        """Validation: AE loss plus PSNR of rendered views.

        Logs input-vs-GT, input-vs-reconstruction and reconstruction-vs-GT PSNR
        and posts side-by-side visualizations to wandb for early batches.
        """
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/')
        self.log_dict(log_dict_ae)

        assert not self.norm
        reconstructions = self.unrollout(reconstructions)
        psnr_list = []        # reconstruction render vs. ground truth
        psnr_input_list = []  # original-triplane render vs. ground truth
        psnr_rec_list = []    # reconstruction render vs. original-triplane render
        batch_size = inputs.shape[0]
        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            # log every 4th sample of the first 10 batches (2nd view, index 1)
            if b % 4 == 0 and batch_idx < 10:
                rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
                self.logger.experiment.log({
                    "val/vis": [wandb.Image(rgb_all)]
                })

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

        return self.log_dict

    def to_rgb(self, plane):
        """Visualize a latent plane as uint8 RGB via a lazily-created random 1x1 conv."""
        x = plane.float()
        if not hasattr(self, "colorize"):
            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = torch.nn.functional.conv2d(x, weight=self.colorize)
        x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return x

    def to_rgb_triplane(self, plane):
        """Visualize a triplane feature map as uint8 RGB via a random 1x1 conv."""
        x = plane.float()
        if not hasattr(self, "colorize_triplane"):
            self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane)
        x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return x

    def test_step(self, batch, batch_idx):
        """Test-time evaluation: AE loss, latent-statistics accumulation,
        visualization dumps (first 10 batches) and rendered PSNR metrics."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/')
        self.log_dict(log_dict_ae)

        batch_size = inputs.shape[0]
        psnr_list = []        # reconstruction render vs. ground truth
        psnr_input_list = []  # original-triplane render vs. ground truth
        psnr_rec_list = []    # reconstruction render vs. original-triplane render

        z = posterior.mode()
        colorize_z = self.to_rgb(z)[0]
        colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
        colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]
        if batch_idx < 10:
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output)

        reconstructions = self.unrollout(reconstructions)

        # lazily move the running statistics to the model device on first use
        if self.psum.device != z.device:
            self.psum = self.psum.to(z.device)
            self.psum_sq = self.psum_sq.to(z.device)
            self.psum_min = self.psum_min.to(z.device)
            self.psum_max = self.psum_max.to(z.device)
        self.psum += z.sum()
        self.psum_sq += (z ** 2).sum()
        # per-batch extrema are summed; divided by len_dset at epoch end
        self.psum_min += z.reshape(-1).min(-1)[0]
        self.psum_max += z.reshape(-1).max(-1)[0]
        assert len(z.shape) == 4
        self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
        self.len_dset += 1

        if self.norm:
            # BUG: `assert NotImplementedError` is a no-op (the exception class
            # is truthy); if self.norm were True this would fall through and the
            # branch below would leave reconstructions_unnormalize... actually
            # defined here, but the intent was clearly `raise NotImplementedError`.
            assert NotImplementedError
        else:
            reconstructions_unnormalize = reconstructions

        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            if batch_idx < 10:
                # dump the 2nd view (index 1) of each sample for inspection
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1])
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1])
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1])

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

    def on_test_epoch_end(self):
        """Print mean/std of all latent elements and average per-batch min/max."""
        mean = self.psum / self.count
        mean_min = self.psum_min / self.len_dset
        mean_max = self.psum_max / self.len_dset
        var = (self.psum_sq / self.count) - (mean ** 2)
        std = torch.sqrt(var)

        print("mean min: {}".format(mean_min))
        print("mean max: {}".format(mean_max))
        print("mean: {}".format(mean))
        print("std: {}".format(std))
|
|
|
|
class AutoencoderKLRollOut3DAwareMeanPool(AutoencoderKL):
    """Rolled-out triplane VAE with *mean-pooled* 3D-aware features.

    Differs from the "OnlyInput" variant in two ways: ``to3daware`` uses
    AvgPool2d instead of MaxPool2d, and ``decode`` re-applies ``to3daware``
    to the latent (tripling its channels), so ``__init__`` rebuilds the
    decoder and post-quant conv for 3x channel counts.
    """

    def __init__(self, *args, **kwargs):
        # Strip ckpt_path so the parent does not load weights before the
        # decoder/post_quant_conv are rebuilt below; re-load at the end.
        try:
            ckpt_path = kwargs['ckpt_path']
            kwargs['ckpt_path'] = None
        except:  # NOTE(review): bare except — KeyError is the intended case
            ckpt_path = None

        super().__init__(*args, **kwargs)
        # Running latent statistics filled in test_step (plain tensors, moved
        # to the model device lazily in test_step).
        self.psum = torch.zeros([1])
        self.psum_sq = torch.zeros([1])
        self.psum_min = torch.zeros([1])
        self.psum_max = torch.zeros([1])
        self.count = 0     # total number of latent elements seen
        self.len_dset = 0  # number of test batches seen

        # NOTE(review): this mutates the caller's ddconfig dict in place
        # (z_channels is tripled), which leaks into any later use of the config.
        ddconfig = kwargs['ddconfig']
        ddconfig['z_channels'] *= 3
        # Rebuild the decoder for the channel-tripled latent produced by
        # to3daware(z) in decode().
        self.decoder = Decoder(**ddconfig)
        self.post_quant_conv = torch.nn.Conv2d(kwargs['embed_dim'] * 3, ddconfig["z_channels"], 1)

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

    def rollout(self, triplane):
        """Reshape a stacked triplane (B, 3*C, H, H) into rolled-out form (B, C, H, 3*H)."""
        res = triplane.shape[-1]
        ch = triplane.shape[1]
        triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res)
        return triplane

    def to3daware(self, triplane):
        """Augment each plane of a rolled-out triplane with *average*-pooled
        features of the other two planes, concatenated along channels (3x C)."""
        res = triplane.shape[-2]
        plane1 = triplane[..., :res]
        plane2 = triplane[..., res:2*res]
        plane3 = triplane[..., 2*res:3*res]

        # Pool away one spatial axis (mean), broadcast back to (res, res), then
        # transpose so the surviving axis lines up with the receiving plane.
        x_mp = torch.nn.AvgPool2d((res, 1))
        y_mp = torch.nn.AvgPool2d((1, res))
        x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
        y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)

        plane21 = x_mp_rep(plane2)
        # flips reorient the pooled features to match the receiving plane's axes
        plane31 = torch.flip(y_mp_rep(plane3), (3,))
        new_plane1 = torch.cat([plane1, plane21, plane31], 1)

        plane12 = y_mp_rep(plane1)
        plane32 = x_mp_rep(plane3)
        new_plane2 = torch.cat([plane2, plane12, plane32], 1)

        plane13 = torch.flip(x_mp_rep(plane1), (2,))
        plane23 = y_mp_rep(plane2)
        new_plane3 = torch.cat([plane3, plane13, plane23], 1)

        new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
        return new_plane

    def unrollout(self, triplane):
        """Inverse of rollout: (B, C, H, 3*H) -> (B, 3*C, H, H)."""
        res = triplane.shape[-2]
        ch = 3 * triplane.shape[1]
        triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res)
        return triplane

    def encode(self, x, rollout=False):
        """Encode into a diagonal-Gaussian posterior.

        With rollout=True, x is a stacked triplane and is rolled out and made
        3D-aware first; otherwise x is assumed already rolled out + 3D-aware.
        """
        if rollout:
            x = self.to3daware(self.rollout(x))
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z, unrollout=False):
        """Decode a latent: first make it 3D-aware (channel-tripled), then run
        the rebuilt decoder; optionally re-stack via unrollout."""
        z = self.to3daware(z)
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        if unrollout:
            dec = self.unrollout(dec)
        return dec

    def forward(self, input, sample_posterior=True):
        """Full VAE pass on an already 3D-aware rolled-out triplane."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def training_step(self, batch, batch_idx):
        """One optimization step; the loss compares against the plain rolled-out input."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(self.to3daware(inputs))
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/')
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
        return aeloss

    def validation_step(self, batch, batch_idx):
        """Validation: AE loss plus rendered PSNR metrics and wandb visualizations."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/')
        self.log_dict(log_dict_ae)

        assert not self.norm
        reconstructions = self.unrollout(reconstructions)
        psnr_list = []        # reconstruction render vs. ground truth
        psnr_input_list = []  # original-triplane render vs. ground truth
        psnr_rec_list = []    # reconstruction render vs. original-triplane render
        batch_size = inputs.shape[0]
        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            # log every 4th sample of the first 10 batches (2nd view, index 1)
            if b % 4 == 0 and batch_idx < 10:
                rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
                self.logger.experiment.log({
                    "val/vis": [wandb.Image(rgb_all)]
                })

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

        return self.log_dict

    def to_rgb(self, plane):
        """Visualize a latent plane as uint8 RGB via a lazily-created random 1x1 conv."""
        x = plane.float()
        if not hasattr(self, "colorize"):
            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = torch.nn.functional.conv2d(x, weight=self.colorize)
        x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return x

    def to_rgb_triplane(self, plane):
        """Visualize a triplane feature map as uint8 RGB via a random 1x1 conv."""
        x = plane.float()
        if not hasattr(self, "colorize_triplane"):
            self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane)
        x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return x

    def to_rgb_3daware(self, plane):
        """Visualize a 3D-aware (channel-tripled) plane as uint8 RGB via a random 1x1 conv."""
        x = plane.float()
        if not hasattr(self, "colorize_3daware"):
            self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware)
        x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return x

    def test_step(self, batch, batch_idx):
        """Test-time evaluation: AE loss, latent-statistics accumulation,
        visualization dumps (first 10 batches) and rendered PSNR metrics."""
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/')
        self.log_dict(log_dict_ae)

        batch_size = inputs.shape[0]
        psnr_list = []        # reconstruction render vs. ground truth
        psnr_input_list = []  # original-triplane render vs. ground truth
        psnr_rec_list = []    # reconstruction render vs. original-triplane render

        z = posterior.mode()
        colorize_z = self.to_rgb(z)[0]
        colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
        colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]
        colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0]
        # NOTE(review): `res` here is the channel count (dim 1), not a spatial
        # resolution; to3daware triples channels, so the slices below pick the
        # 2nd and 3rd channel groups of the 3D-aware tensor.
        res = inputs.shape[1]
        colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0]
        colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0]
        if batch_idx < 10:
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1)
            imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2)

        reconstructions = self.unrollout(reconstructions)

        # lazily move the running statistics to the model device on first use
        if self.psum.device != z.device:
            self.psum = self.psum.to(z.device)
            self.psum_sq = self.psum_sq.to(z.device)
            self.psum_min = self.psum_min.to(z.device)
            self.psum_max = self.psum_max.to(z.device)
        self.psum += z.sum()
        self.psum_sq += (z ** 2).sum()
        # per-batch extrema are summed; divided by len_dset at epoch end
        self.psum_min += z.reshape(-1).min(-1)[0]
        self.psum_max += z.reshape(-1).max(-1)[0]
        assert len(z.shape) == 4
        self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
        self.len_dset += 1

        if self.norm:
            # BUG: `assert NotImplementedError` is a no-op (the exception class
            # is truthy); the intent was clearly `raise NotImplementedError`.
            assert NotImplementedError
        else:
            reconstructions_unnormalize = reconstructions

        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            if batch_idx < 10:
                # dump the 2nd view (index 1) of each sample for inspection
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1])
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1])
                imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1])

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

    def on_test_epoch_end(self):
        """Print mean/std of all latent elements and average per-batch min/max."""
        mean = self.psum / self.count
        mean_min = self.psum_min / self.len_dset
        mean_max = self.psum_max / self.len_dset
        var = (self.psum_sq / self.count) - (mean ** 2)
        std = torch.sqrt(var)

        print("mean min: {}".format(mean_min))
        print("mean max: {}".format(mean_max))
        print("mean: {}".format(mean))
        print("std: {}".format(std))
|
|
|
|
| class AutoencoderKLGroupConv(AutoencoderKL): |
    def __init__(self, *args, **kwargs):
        """Group-convolution variant: replaces the parent's encoder/decoder with
        Encoder_GroupConv/Decoder_GroupConv and optionally normalizes triplanes
        with dataset mean/std read from ddconfig."""
        # Strip ckpt_path so the parent does not load weights before the modules
        # below are rebuilt; the checkpoint is loaded again at the end instead.
        try:
            ckpt_path = kwargs['ckpt_path']
            kwargs['ckpt_path'] = None
        except:  # NOTE(review): bare except — KeyError is the intended case
            ckpt_path = None

        super().__init__(*args, **kwargs)
        self.latent_list = []  # latent codes collected during test_step
        # Running latent statistics (moved to the model device lazily in test_step).
        self.psum = torch.zeros([1])
        self.psum_sq = torch.zeros([1])
        self.psum_min = torch.zeros([1])
        self.psum_max = torch.zeros([1])
        self.count = 0     # total number of latent elements seen
        self.len_dset = 0  # number of test batches seen

        ddconfig = kwargs['ddconfig']

        # Swap in the group-convolution encoder/decoder.
        del self.decoder
        del self.encoder
        self.encoder = Encoder_GroupConv(**ddconfig)
        self.decoder = Decoder_GroupConv(**ddconfig)

        if "mean" in ddconfig:
            print("Using mean std!!")
            # Plain tensors (not registered buffers); encode/decode move them
            # with .to(device) at use time. Shape (1, C, 1, 1) for broadcasting.
            self.triplane_mean = torch.Tensor(ddconfig['mean']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float()
            self.triplane_std = torch.Tensor(ddconfig['std']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float()
        else:
            self.triplane_mean = None
            self.triplane_std = None

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)
|
|
| def rollout(self, triplane): |
| res = triplane.shape[-1] |
| ch = triplane.shape[1] |
| triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) |
| return triplane |
|
|
| def to3daware(self, triplane): |
| res = triplane.shape[-2] |
| plane1 = triplane[..., :res] |
| plane2 = triplane[..., res:2*res] |
| plane3 = triplane[..., 2*res:3*res] |
|
|
| x_mp = torch.nn.MaxPool2d((res, 1)) |
| y_mp = torch.nn.MaxPool2d((1, res)) |
| x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) |
| y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) |
| |
| plane21 = x_mp_rep(plane2) |
| plane31 = torch.flip(y_mp_rep(plane3), (3,)) |
| new_plane1 = torch.cat([plane1, plane21, plane31], 1) |
| |
| plane12 = y_mp_rep(plane1) |
| plane32 = x_mp_rep(plane3) |
| new_plane2 = torch.cat([plane2, plane12, plane32], 1) |
| |
| plane13 = torch.flip(x_mp_rep(plane1), (2,)) |
| plane23 = y_mp_rep(plane2) |
| new_plane3 = torch.cat([plane3, plane13, plane23], 1) |
|
|
| new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() |
| return new_plane |
|
|
| def unrollout(self, triplane): |
| res = triplane.shape[-2] |
| ch = 3 * triplane.shape[1] |
| triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) |
| return triplane |
|
|
| def encode(self, x, rollout=False): |
| if rollout: |
| |
| x = self.rollout(x) |
| if self.triplane_mean is not None: |
| x = (x - self.triplane_mean.to(x.device)) / self.triplane_std.to(x.device) |
| h = self.encoder(x) |
| moments = self.quant_conv(h) |
| posterior = DiagonalGaussianDistribution(moments) |
| return posterior |
|
|
| def decode(self, z, unrollout=False): |
| |
| z = self.post_quant_conv(z) |
| dec = self.decoder(z) |
| if self.triplane_mean is not None: |
| dec = dec * self.triplane_std.to(dec.device) + self.triplane_mean.to(dec.device) |
| if unrollout: |
| dec = self.unrollout(dec) |
| return dec |
|
|
| def forward(self, input, sample_posterior=True): |
| posterior = self.encode(input) |
| if sample_posterior: |
| z = posterior.sample() |
| else: |
| z = posterior.mode() |
| dec = self.decode(z) |
| return dec, posterior |
|
|
| def training_step(self, batch, batch_idx): |
| inputs = self.rollout(batch['triplane']) |
| reconstructions, posterior = self(inputs) |
| aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch) |
| self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) |
| return aeloss |
|
|
    def validation_step(self, batch, batch_idx):
        """Validation: AE loss plus rendered PSNR metrics and wandb visualizations.

        Logs three PSNR aggregates (input-vs-GT, input-vs-reconstruction,
        reconstruction-vs-GT) and, for early batches, side-by-side image strips
        plus a colorized latent visualization.
        """
        inputs = self.rollout(batch['triplane'])
        reconstructions, posterior = self(inputs, sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None)
        self.log_dict(log_dict_ae)

        z = posterior.mode()
        colorize_z = self.to_rgb(z)[0]
        assert not self.norm
        reconstructions = self.unrollout(reconstructions)
        psnr_list = []        # reconstruction render vs. ground truth
        psnr_input_list = []  # original-triplane render vs. ground truth
        psnr_rec_list = []    # reconstruction render vs. original-triplane render
        batch_size = inputs.shape[0]
        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError

            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))

            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            # Swap the channel order (RGB<->BGR) of the *rendered* images only.
            # NOTE(review): rgb_gt is not swapped — presumably the renderer
            # emits a different channel order than the ground truth; confirm
            # against the renderer before relying on these visualizations.
            rgb_input = np.stack([rgb_input[..., 2], rgb_input[..., 1], rgb_input[..., 0]], -1)
            rgb = np.stack([rgb[..., 2], rgb[..., 1], rgb[..., 0]], -1)

            # log every 2nd sample of the first 10 batches (2nd view, index 1)
            if b % 2 == 0 and batch_idx < 10:
                rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
                self.logger.experiment.log({
                    "val/vis": [wandb.Image(rgb_all)],
                    "val/latent_vis": [wandb.Image(colorize_z)]
                })

            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec

        self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)

        return self.log_dict
|
|
| def to_rgb(self, plane): |
| x = plane.float() |
| if not hasattr(self, "colorize"): |
| self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = torch.nn.functional.conv2d(x, weight=self.colorize) |
| x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
| return x |
|
|
| def to_rgb_triplane(self, plane): |
| x = plane.float() |
| if not hasattr(self, "colorize_triplane"): |
| self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) |
| x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
| return x |
| |
| def to_rgb_3daware(self, plane): |
| x = plane.float() |
| if not hasattr(self, "colorize_3daware"): |
| self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) |
| x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
| return x |
|
|
| def test_step(self, batch, batch_idx): |
| inputs = self.rollout(batch['triplane']) |
| reconstructions, posterior = self(inputs, sample_posterior=False) |
| aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) |
| self.log_dict(log_dict_ae) |
|
|
| batch_size = inputs.shape[0] |
| psnr_list = [] |
| psnr_input_list = [] |
| psnr_rec_list = [] |
|
|
| z = posterior.mode() |
| colorize_z = self.to_rgb(z)[0] |
| colorize_triplane_input = self.to_rgb_triplane(inputs)[0] |
| colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] |
|
|
| import os |
| import random |
| import string |
| |
| z_np = inputs.detach().cpu().numpy() |
| fname = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) + '.npy' |
| with open(os.path.join('/mnt/lustre/hongfangzhou.p/AE3D/tmp', fname), 'wb') as f: |
| np.save(f, z_np) |
|
|
| |
| |
| |
| |
| if batch_idx < 0: |
| imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) |
| |
| |
| |
|
|
| np_z = z.detach().cpu().numpy() |
| |
| |
|
|
| self.latent_list.append(np_z) |
|
|
| reconstructions = self.unrollout(reconstructions) |
|
|
| if self.psum.device != z.device: |
| self.psum = self.psum.to(z.device) |
| self.psum_sq = self.psum_sq.to(z.device) |
| self.psum_min = self.psum_min.to(z.device) |
| self.psum_max = self.psum_max.to(z.device) |
| self.psum += z.sum() |
| self.psum_sq += (z ** 2).sum() |
| self.psum_min += z.reshape(-1).min(-1)[0] |
| self.psum_max += z.reshape(-1).max(-1)[0] |
| assert len(z.shape) == 4 |
| self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] |
| self.len_dset += 1 |
|
|
| if self.norm: |
| assert NotImplementedError |
| else: |
| reconstructions_unnormalize = reconstructions |
|
|
| if True: |
| for b in range(batch_size): |
| if self.renderer_type == 'nerf': |
| rgb_input, cur_psnr_list_input = self.render_triplane( |
| batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
| batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
| ) |
| rgb, cur_psnr_list = self.render_triplane( |
| reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
| batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
| ) |
| elif self.renderer_type == 'eg3d': |
| rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( |
| batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], |
| ) |
| rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( |
| reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], |
| ) |
| else: |
| raise NotImplementedError |
|
|
| cur_psnr_list_rec = [] |
| for i in range(rgb.shape[0]): |
| cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) |
|
|
| rgb_input = to8b(rgb_input.detach().cpu().numpy()) |
| rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) |
| rgb = to8b(rgb.detach().cpu().numpy()) |
| |
| if batch_idx < 10: |
| imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) |
|
|
| psnr_list += cur_psnr_list |
| psnr_input_list += cur_psnr_list_input |
| psnr_rec_list += cur_psnr_list_rec |
|
|
| self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) |
| self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) |
| self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) |
|
|
    def on_test_epoch_end(self):
        """Print summary statistics of the latents accumulated in test_step."""
        # Global per-element mean over all latent values seen during testing.
        mean = self.psum / self.count
        # psum_min/psum_max hold one extremum per test batch, so dividing by
        # the batch count gives the *average per-batch* min/max, not a global one.
        mean_min = self.psum_min / self.len_dset
        mean_max = self.psum_max / self.len_dset
        # Var = E[z^2] - E[z]^2; can go slightly negative from floating-point
        # rounding, making sqrt return NaN — NOTE(review): confirm acceptable.
        var = (self.psum_sq / self.count) - (mean ** 2)
        std = torch.sqrt(var)


        print("mean min: {}".format(mean_min))
        print("mean max: {}".format(mean_max))
        print("mean: {}".format(mean))
        print("std: {}".format(std))


        # Robust scale estimate: 0.7413 * IQR equals one standard deviation
        # for a Gaussian distribution.
        latent = np.concatenate(self.latent_list)
        q75, q25 = np.percentile(latent.reshape(-1), [75 ,25])
        median = np.median(latent.reshape(-1))
        iqr = q75 - q25
        norm_iqr = iqr * 0.7413
        print("Norm IQR: {}".format(norm_iqr))
        print("Inverse Norm IQR: {}".format(1/norm_iqr))
        print("Median: {}".format(median))
|
|
    def loss(self, inputs, reconstructions, posteriors, prefix, batch=None):
        """Total autoencoder loss: weighted MSE reconstruction + mean KL,
        plus optional render / TV / L1 regularizers from ``self.lossconfig``.

        Returns ``(loss, ret_dict)``; ``ret_dict`` maps ``prefix``-ed metric
        names to scalar tensors for logging. ``batch`` is only needed when
        ``render_weight > 0`` (supplies rays and target images).
        """
        reconstructions = reconstructions.contiguous()


        rec_loss = F.mse_loss(inputs.contiguous(), reconstructions)
        kl_loss = posteriors.kl()

        kl_loss = kl_loss.mean()
        loss = self.lossconfig.rec_weight * rec_loss + self.lossconfig.kl_weight * kl_loss


        ret_dict = {
            prefix+'mean_rec_loss': torch.abs(inputs.contiguous() - reconstructions.contiguous()).mean().detach(),
            prefix+'rec_loss': rec_loss,
            prefix+'kl_loss': kl_loss,
            prefix+'loss': loss,
            prefix+'mean': posteriors.mean.mean(),
            prefix+'logvar': posteriors.logvar.mean(),
        }


        # Track the range of the posterior mean for monitoring latent scale.
        latent = posteriors.mean
        ret_dict[prefix + 'latent_max'] = latent.max()
        ret_dict[prefix + 'latent_min'] = latent.min()


        # Optional regularizer weights; all default to disabled.
        render_weight = self.lossconfig.get("render_weight", 0)
        tv_weight = self.lossconfig.get("tv_weight", 0)
        l1_weight = self.lossconfig.get("l1_weight", 0)
        latent_tv_weight = self.lossconfig.get("latent_tv_weight", 0)
        latent_l1_weight = self.lossconfig.get("latent_l1_weight", 0)


        triplane_rec = self.unrollout(reconstructions)
        if render_weight > 0 and batch is not None:
            # Photometric loss on pixels rendered from the reconstructed triplane.
            rgb_rendered, target = self.render_triplane_eg3d_decoder_sample_pixel(triplane_rec, batch['batch_rays'], batch['img'])

            render_loss = F.mse_loss(rgb_rendered, target)
            loss += render_weight * render_loss
            ret_dict[prefix + 'render_loss'] = render_loss
        if tv_weight > 0:
            # Total-variation smoothness over the reconstructed triplane.
            tvloss_y = F.mse_loss(triplane_rec[:, :, :-1], triplane_rec[:, :, 1:])
            tvloss_x = F.mse_loss(triplane_rec[:, :, :, :-1], triplane_rec[:, :, :, 1:])
            tvloss = tvloss_y + tvloss_x
            loss += tv_weight * tvloss
            ret_dict[prefix + 'tv_loss'] = tvloss
        if l1_weight > 0:
            # NOTE(review): named "l1" but this is a mean *squared* (L2) penalty.
            l1 = (triplane_rec ** 2).mean()
            loss += l1_weight * l1
            ret_dict[prefix + 'l1_loss'] = l1
        if latent_tv_weight > 0:
            # Total-variation smoothness on the posterior mean.
            latent = posteriors.mean
            latent_tv_y = F.mse_loss(latent[:, :, :-1], latent[:, :, 1:])
            latent_tv_x = F.mse_loss(latent[:, :, :, :-1], latent[:, :, :, 1:])
            latent_tv_loss = latent_tv_y + latent_tv_x
            loss += latent_tv_loss * latent_tv_weight
            ret_dict[prefix + 'latent_tv_loss'] = latent_tv_loss
        if latent_l1_weight > 0:
            # NOTE(review): also a squared (L2) penalty despite the "l1" name.
            latent = posteriors.mean
            latent_l1_loss = (latent ** 2).mean()
            loss += latent_l1_loss * latent_l1_weight
            ret_dict[prefix + 'latent_l1_loss'] = latent_l1_loss


        return loss, ret_dict
|
|
|
|
class AutoencoderKLGroupConvLateFusion(AutoencoderKL):
    """AutoencoderKL variant whose encoder/decoder use grouped convolutions
    with late fusion across the three triplane planes."""

    def __init__(self, *args, **kwargs):
        # Pop ckpt_path so the parent __init__ does not load checkpoint
        # weights into the default encoder/decoder, which are replaced below.
        try:
            ckpt_path = kwargs['ckpt_path']
            kwargs['ckpt_path'] = None
        except KeyError:
            # Fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit; only a missing key is expected.
            ckpt_path = None

        super().__init__(*args, **kwargs)
        # Accumulators for latent statistics gathered during testing.
        self.latent_list = []
        self.psum = torch.zeros([1])
        self.psum_sq = torch.zeros([1])
        self.psum_min = torch.zeros([1])
        self.psum_max = torch.zeros([1])
        self.count = 0
        self.len_dset = 0

        # Swap in the group-conv late-fusion encoder/decoder.
        # NOTE(review): requires ddconfig to be passed as a keyword argument.
        ddconfig = kwargs['ddconfig']
        del self.decoder
        del self.encoder
        self.encoder = Encoder_GroupConv_LateFusion(**ddconfig)
        self.decoder = Decoder_GroupConv_LateFusion(**ddconfig)

        # Load the checkpoint now that the modules match its weights.
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)
|
|
| def rollout(self, triplane): |
| res = triplane.shape[-1] |
| ch = triplane.shape[1] |
| triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) |
| return triplane |
|
|
| def to3daware(self, triplane): |
| res = triplane.shape[-2] |
| plane1 = triplane[..., :res] |
| plane2 = triplane[..., res:2*res] |
| plane3 = triplane[..., 2*res:3*res] |
|
|
| x_mp = torch.nn.MaxPool2d((res, 1)) |
| y_mp = torch.nn.MaxPool2d((1, res)) |
| x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) |
| y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) |
| |
| plane21 = x_mp_rep(plane2) |
| plane31 = torch.flip(y_mp_rep(plane3), (3,)) |
| new_plane1 = torch.cat([plane1, plane21, plane31], 1) |
| |
| plane12 = y_mp_rep(plane1) |
| plane32 = x_mp_rep(plane3) |
| new_plane2 = torch.cat([plane2, plane12, plane32], 1) |
| |
| plane13 = torch.flip(x_mp_rep(plane1), (2,)) |
| plane23 = y_mp_rep(plane2) |
| new_plane3 = torch.cat([plane3, plane13, plane23], 1) |
|
|
| new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() |
| return new_plane |
|
|
| def unrollout(self, triplane): |
| res = triplane.shape[-2] |
| ch = 3 * triplane.shape[1] |
| triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) |
| return triplane |
|
|
| def encode(self, x, rollout=False): |
| if rollout: |
| x = self.rollout(x) |
| h = self.encoder(x) |
| moments = self.quant_conv(h) |
| posterior = DiagonalGaussianDistribution(moments) |
| return posterior |
|
|
| def decode(self, z, unrollout=False): |
| z = self.post_quant_conv(z) |
| dec = self.decoder(z) |
| if unrollout: |
| dec = self.unrollout(dec) |
| return dec |
|
|
| def forward(self, input, sample_posterior=True): |
| posterior = self.encode(input) |
| if sample_posterior: |
| z = posterior.sample() |
| else: |
| z = posterior.mode() |
| dec = self.decode(z) |
| return dec, posterior |
|
|
| def training_step(self, batch, batch_idx): |
| inputs = self.rollout(batch['triplane']) |
| reconstructions, posterior = self(inputs) |
| aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch) |
| self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) |
| return aeloss |
|
|
    def validation_step(self, batch, batch_idx):
        """Validate reconstruction quality: log AE losses, render the input
        and reconstructed triplanes, log PSNR metrics and wandb previews."""
        inputs = self.rollout(batch['triplane'])
        # Decode the posterior mode (no sampling) for deterministic validation.
        reconstructions, posterior = self(inputs, sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None)
        self.log_dict(log_dict_ae)


        assert not self.norm
        reconstructions = self.unrollout(reconstructions)
        psnr_list = []        # reconstruction render vs ground truth
        psnr_input_list = []  # input-triplane render vs ground truth
        psnr_rec_list = []    # input render vs reconstruction render
        batch_size = inputs.shape[0]
        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError


            # PSNR between the input render and the reconstruction render.
            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))


            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            # Log a side-by-side (GT | input render | reconstruction render)
            # preview of view index 1 for every 4th item of early batches.
            if b % 4 == 0 and batch_idx < 10:
                rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
                self.logger.experiment.log({
                    "val/vis": [wandb.Image(rgb_all)]
                })


            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec


        self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)


        # NOTE(review): returns the bound log_dict method itself, not a dict;
        # Lightning ignores validation_step return values, so this is inert.
        return self.log_dict
|
|
| def to_rgb(self, plane): |
| x = plane.float() |
| if not hasattr(self, "colorize"): |
| self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = torch.nn.functional.conv2d(x, weight=self.colorize) |
| x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
| return x |
|
|
| def to_rgb_triplane(self, plane): |
| x = plane.float() |
| if not hasattr(self, "colorize_triplane"): |
| self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) |
| x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
| return x |
| |
| def to_rgb_3daware(self, plane): |
| x = plane.float() |
| if not hasattr(self, "colorize_3daware"): |
| self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) |
| x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
| return x |
|
|
| def test_step(self, batch, batch_idx): |
| inputs = self.rollout(batch['triplane']) |
| reconstructions, posterior = self(inputs, sample_posterior=False) |
| aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) |
| self.log_dict(log_dict_ae) |
|
|
| batch_size = inputs.shape[0] |
| psnr_list = [] |
| psnr_input_list = [] |
| psnr_rec_list = [] |
|
|
| z = posterior.mode() |
| colorize_z = self.to_rgb(z)[0] |
| colorize_triplane_input = self.to_rgb_triplane(inputs)[0] |
| colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] |
| |
| |
| |
| |
| if batch_idx < 10: |
| imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) |
| |
| |
| |
|
|
| np_z = z.detach().cpu().numpy() |
| |
| |
|
|
| self.latent_list.append(np_z) |
|
|
| reconstructions = self.unrollout(reconstructions) |
|
|
| if self.psum.device != z.device: |
| self.psum = self.psum.to(z.device) |
| self.psum_sq = self.psum_sq.to(z.device) |
| self.psum_min = self.psum_min.to(z.device) |
| self.psum_max = self.psum_max.to(z.device) |
| self.psum += z.sum() |
| self.psum_sq += (z ** 2).sum() |
| self.psum_min += z.reshape(-1).min(-1)[0] |
| self.psum_max += z.reshape(-1).max(-1)[0] |
| assert len(z.shape) == 4 |
| self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] |
| self.len_dset += 1 |
|
|
| if self.norm: |
| assert NotImplementedError |
| else: |
| reconstructions_unnormalize = reconstructions |
|
|
| for b in range(batch_size): |
| if self.renderer_type == 'nerf': |
| rgb_input, cur_psnr_list_input = self.render_triplane( |
| batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
| batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
| ) |
| rgb, cur_psnr_list = self.render_triplane( |
| reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
| batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
| ) |
| elif self.renderer_type == 'eg3d': |
| rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( |
| batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], |
| ) |
| rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( |
| reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], |
| ) |
| else: |
| raise NotImplementedError |
|
|
| cur_psnr_list_rec = [] |
| for i in range(rgb.shape[0]): |
| cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) |
|
|
| rgb_input = to8b(rgb_input.detach().cpu().numpy()) |
| rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) |
| rgb = to8b(rgb.detach().cpu().numpy()) |
| |
| if batch_idx < 10: |
| imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) |
| imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) |
|
|
| psnr_list += cur_psnr_list |
| psnr_input_list += cur_psnr_list_input |
| psnr_rec_list += cur_psnr_list_rec |
|
|
| self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) |
| self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) |
| self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) |
|
|
    def on_test_epoch_end(self):
        """Print summary statistics of the latents accumulated in test_step."""
        # Global per-element mean over all latent values seen during testing.
        mean = self.psum / self.count
        # psum_min/psum_max hold one extremum per test batch, so dividing by
        # the batch count gives the *average per-batch* min/max, not a global one.
        mean_min = self.psum_min / self.len_dset
        mean_max = self.psum_max / self.len_dset
        # Var = E[z^2] - E[z]^2; can go slightly negative from floating-point
        # rounding, making sqrt return NaN — NOTE(review): confirm acceptable.
        var = (self.psum_sq / self.count) - (mean ** 2)
        std = torch.sqrt(var)


        print("mean min: {}".format(mean_min))
        print("mean max: {}".format(mean_max))
        print("mean: {}".format(mean))
        print("std: {}".format(std))


        # Robust scale estimate: 0.7413 * IQR equals one standard deviation
        # for a Gaussian distribution.
        latent = np.concatenate(self.latent_list)
        q75, q25 = np.percentile(latent.reshape(-1), [75 ,25])
        median = np.median(latent.reshape(-1))
        iqr = q75 - q25
        norm_iqr = iqr * 0.7413
        print("Norm IQR: {}".format(norm_iqr))
        print("Inverse Norm IQR: {}".format(1/norm_iqr))
        print("Median: {}".format(median))
|
|
|
|
| from module.model_2d import ViTEncoder, ViTDecoder |
|
|
class AutoencoderVIT(AutoencoderKL):
    """AutoencoderKL variant using ViT encoder/decoder over rolled-out triplanes."""

    def __init__(self, *args, **kwargs):
        # Pop ckpt_path so the parent __init__ does not load checkpoint
        # weights into the default modules, which are replaced below.
        try:
            ckpt_path = kwargs['ckpt_path']
            kwargs['ckpt_path'] = None
        except KeyError:
            # Fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit; only a missing key is expected.
            ckpt_path = None

        super().__init__(*args, **kwargs)
        # Accumulators for latent statistics gathered during testing.
        self.latent_list = []
        self.psum = torch.zeros([1])
        self.psum_sq = torch.zeros([1])
        self.psum_min = torch.zeros([1])
        self.psum_max = torch.zeros([1])
        self.count = 0
        self.len_dset = 0

        # NOTE(review): requires ddconfig to be passed as a keyword argument.
        ddconfig = kwargs['ddconfig']

        # Replace the conv encoder/decoder and their 1x1 quant convs with a
        # ViT pair operating on the 256 x 768 rolled-out triplane.
        del self.decoder
        del self.encoder
        del self.quant_conv
        del self.post_quant_conv

        assert ddconfig["z_channels"] == 256
        self.encoder = ViTEncoder(
            image_size=(256, 256*3),
            patch_size=(256//32, 256//32),
            dim=768,
            depth=12,
            heads=12,
            mlp_dim=3072,
            channels=8)
        self.decoder = ViTDecoder(
            image_size=(256, 256*3),
            patch_size=(256//32, 256//32),
            dim=768,
            depth=12,
            heads=12,
            mlp_dim=3072,
            channels=8)

        self.quant_conv = torch.nn.Conv2d(768, 2*self.embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, 768, 1)

        # Optional per-channel normalization statistics for the triplane,
        # reshaped to broadcast over (B, C, H, W).
        if "mean" in ddconfig:
            print("Using mean std!!")
            self.triplane_mean = torch.Tensor(ddconfig['mean']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float()
            self.triplane_std = torch.Tensor(ddconfig['std']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float()
        else:
            self.triplane_mean = None
            self.triplane_std = None

        # Load the checkpoint now that the modules match its weights.
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)
|
|
| def rollout(self, triplane): |
| res = triplane.shape[-1] |
| ch = triplane.shape[1] |
| triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) |
| return triplane |
|
|
| def to3daware(self, triplane): |
| res = triplane.shape[-2] |
| plane1 = triplane[..., :res] |
| plane2 = triplane[..., res:2*res] |
| plane3 = triplane[..., 2*res:3*res] |
|
|
| x_mp = torch.nn.MaxPool2d((res, 1)) |
| y_mp = torch.nn.MaxPool2d((1, res)) |
| x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) |
| y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) |
| |
| plane21 = x_mp_rep(plane2) |
| plane31 = torch.flip(y_mp_rep(plane3), (3,)) |
| new_plane1 = torch.cat([plane1, plane21, plane31], 1) |
| |
| plane12 = y_mp_rep(plane1) |
| plane32 = x_mp_rep(plane3) |
| new_plane2 = torch.cat([plane2, plane12, plane32], 1) |
| |
| plane13 = torch.flip(x_mp_rep(plane1), (2,)) |
| plane23 = y_mp_rep(plane2) |
| new_plane3 = torch.cat([plane3, plane13, plane23], 1) |
|
|
| new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() |
| return new_plane |
|
|
| def unrollout(self, triplane): |
| res = triplane.shape[-2] |
| ch = 3 * triplane.shape[1] |
| triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) |
| return triplane |
|
|
| def encode(self, x, rollout=False): |
| if rollout: |
| |
| x = self.rollout(x) |
| if self.triplane_mean is not None: |
| x = (x - self.triplane_mean.to(x.device)) / self.triplane_std.to(x.device) |
| h = self.encoder(x) |
| moments = self.quant_conv(h) |
| posterior = DiagonalGaussianDistribution(moments) |
| return posterior |
|
|
| def decode(self, z, unrollout=False): |
| |
| z = self.post_quant_conv(z) |
| dec = self.decoder(z) |
| if self.triplane_mean is not None: |
| dec = dec * self.triplane_std.to(dec.device) + self.triplane_mean.to(dec.device) |
| if unrollout: |
| dec = self.unrollout(dec) |
| return dec |
|
|
| def forward(self, input, sample_posterior=True): |
| posterior = self.encode(input) |
| if sample_posterior: |
| z = posterior.sample() |
| else: |
| z = posterior.mode() |
| dec = self.decode(z) |
| return dec, posterior |
|
|
| def training_step(self, batch, batch_idx): |
| inputs = self.rollout(batch['triplane']) |
| reconstructions, posterior = self(inputs) |
| aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch) |
| self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) |
| return aeloss |
|
|
    def validation_step(self, batch, batch_idx):
        """Validate reconstruction quality: log AE losses, render the input
        and reconstructed triplanes, log PSNR metrics and wandb previews."""
        inputs = self.rollout(batch['triplane'])
        # Decode the posterior mode (no sampling) for deterministic validation.
        reconstructions, posterior = self(inputs, sample_posterior=False)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None)
        self.log_dict(log_dict_ae)


        assert not self.norm
        reconstructions = self.unrollout(reconstructions)
        psnr_list = []        # reconstruction render vs ground truth
        psnr_input_list = []  # input-triplane render vs ground truth
        psnr_rec_list = []    # input render vs reconstruction render
        batch_size = inputs.shape[0]
        for b in range(batch_size):
            if self.renderer_type == 'nerf':
                rgb_input, cur_psnr_list_input = self.render_triplane(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
                rgb, cur_psnr_list = self.render_triplane(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
                    batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
                )
            elif self.renderer_type == 'eg3d':
                rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
                    batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
                rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
                    reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
                )
            else:
                raise NotImplementedError


            # PSNR between the input render and the reconstruction render.
            cur_psnr_list_rec = []
            for i in range(rgb.shape[0]):
                cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))


            rgb_input = to8b(rgb_input.detach().cpu().numpy())
            rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
            rgb = to8b(rgb.detach().cpu().numpy())

            # Log a side-by-side (GT | input render | reconstruction render)
            # preview of view index 1 for every 4th item of early batches.
            if b % 4 == 0 and batch_idx < 10:
                rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
                self.logger.experiment.log({
                    "val/vis": [wandb.Image(rgb_all)]
                })


            psnr_list += cur_psnr_list
            psnr_input_list += cur_psnr_list_input
            psnr_rec_list += cur_psnr_list_rec


        self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
        self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
        self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)


        # NOTE(review): returns the bound log_dict method itself, not a dict;
        # Lightning ignores validation_step return values, so this is inert.
        return self.log_dict
|
|
| def to_rgb(self, plane): |
| x = plane.float() |
| if not hasattr(self, "colorize"): |
| self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = torch.nn.functional.conv2d(x, weight=self.colorize) |
| x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
| return x |
|
|
| def to_rgb_triplane(self, plane): |
| x = plane.float() |
| if not hasattr(self, "colorize_triplane"): |
| self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) |
| x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
| return x |
| |
| def to_rgb_3daware(self, plane): |
| x = plane.float() |
| if not hasattr(self, "colorize_3daware"): |
| self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) |
| x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) |
| return x |
|
|
| def test_step(self, batch, batch_idx): |
| inputs = self.rollout(batch['triplane']) |
| reconstructions, posterior = self(inputs, sample_posterior=False) |
| aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) |
| self.log_dict(log_dict_ae) |
|
|
| batch_size = inputs.shape[0] |
| psnr_list = [] |
| psnr_input_list = [] |
| psnr_rec_list = [] |
|
|
| z = posterior.mode() |
| colorize_z = self.to_rgb(z)[0] |
| colorize_triplane_input = self.to_rgb_triplane(inputs)[0] |
| colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] |
|
|
| import os |
| import random |
| import string |
| |
| z_np = inputs.detach().cpu().numpy() |
| fname = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) + '.npy' |
| with open(os.path.join('/mnt/lustre/hongfangzhou.p/AE3D/tmp', fname), 'wb') as f: |
| np.save(f, z_np) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| np_z = z.detach().cpu().numpy() |
| |
| |
|
|
| self.latent_list.append(np_z) |
|
|
| reconstructions = self.unrollout(reconstructions) |
|
|
| if self.psum.device != z.device: |
| self.psum = self.psum.to(z.device) |
| self.psum_sq = self.psum_sq.to(z.device) |
| self.psum_min = self.psum_min.to(z.device) |
| self.psum_max = self.psum_max.to(z.device) |
| self.psum += z.sum() |
| self.psum_sq += (z ** 2).sum() |
| self.psum_min += z.reshape(-1).min(-1)[0] |
| self.psum_max += z.reshape(-1).max(-1)[0] |
| assert len(z.shape) == 4 |
| self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] |
| self.len_dset += 1 |
|
|
| if self.norm: |
| assert NotImplementedError |
| else: |
| reconstructions_unnormalize = reconstructions |
|
|
| if True: |
| for b in range(batch_size): |
| if self.renderer_type == 'nerf': |
| rgb_input, cur_psnr_list_input = self.render_triplane( |
| batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
| batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
| ) |
| rgb, cur_psnr_list = self.render_triplane( |
| reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], |
| batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) |
| ) |
| elif self.renderer_type == 'eg3d': |
| rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( |
| batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], |
| ) |
| rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( |
| reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], |
| ) |
| else: |
| raise NotImplementedError |
|
|
| cur_psnr_list_rec = [] |
| for i in range(rgb.shape[0]): |
| cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) |
|
|
| rgb_input = to8b(rgb_input.detach().cpu().numpy()) |
| rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) |
| rgb = to8b(rgb.detach().cpu().numpy()) |
| |
| |
| |
| |
| |
|
|
| psnr_list += cur_psnr_list |
| psnr_input_list += cur_psnr_list_input |
| psnr_rec_list += cur_psnr_list_rec |
|
|
| self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) |
| self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) |
| self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) |
|
|
    def on_test_epoch_end(self):
        """Print summary statistics of the latents accumulated in test_step."""
        # Global per-element mean over all latent values seen during testing.
        mean = self.psum / self.count
        # psum_min/psum_max hold one extremum per test batch, so dividing by
        # the batch count gives the *average per-batch* min/max, not a global one.
        mean_min = self.psum_min / self.len_dset
        mean_max = self.psum_max / self.len_dset
        # Var = E[z^2] - E[z]^2; can go slightly negative from floating-point
        # rounding, making sqrt return NaN — NOTE(review): confirm acceptable.
        var = (self.psum_sq / self.count) - (mean ** 2)
        std = torch.sqrt(var)


        print("mean min: {}".format(mean_min))
        print("mean max: {}".format(mean_max))
        print("mean: {}".format(mean))
        print("std: {}".format(std))


        # Robust scale estimate: 0.7413 * IQR equals one standard deviation
        # for a Gaussian distribution.
        latent = np.concatenate(self.latent_list)
        q75, q25 = np.percentile(latent.reshape(-1), [75 ,25])
        median = np.median(latent.reshape(-1))
        iqr = q75 - q25
        norm_iqr = iqr * 0.7413
        print("Norm IQR: {}".format(norm_iqr))
        print("Inverse Norm IQR: {}".format(1/norm_iqr))
        print("Median: {}".format(median))
|
|