import copy
import math

import torch
from omegaconf import OmegaConf

import models
from models import register
from models.ldm.ldm_base import LDMBase, LDMBaseAudio
from models.ldm.vqgan.lpips import LPIPS


class DiTo(LDMBase):

    def __init__(self, render_diffusion, render_sampler, render_n_steps,
                 renderer_guidance=1, lpips=False, **kwargs):
        super().__init__(**kwargs)
        self.render_diffusion = models.make(render_diffusion)

        # Build the renderer's sampler spec, injecting the diffusion object into its args.
        if OmegaConf.is_config(render_sampler):
            render_sampler = OmegaConf.to_container(render_sampler, resolve=True)
        render_sampler = copy.deepcopy(render_sampler)
        if render_sampler.get('args') is None:
            render_sampler['args'] = {}
        render_sampler['args']['diffusion'] = self.render_diffusion
        self.render_sampler = models.make(render_sampler)

        self.render_n_steps = render_n_steps
        self.renderer_guidance = renderer_guidance

        # Running statistics for the per-timestep loss monitor (10 buckets over t in [0, 1]).
        self.t_loss_monitor_v = [0 for _ in range(10)]
        self.t_loss_monitor_n = [0 for _ in range(10)]
        self.t_loss_monitor_decay = 0.99

        self.use_lpips = lpips
        if lpips:
            self.lpips_loss = LPIPS().eval()
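    # Illustrative sketch of the sampler spec consumed above (hedged: only the 'args'
    # key is read explicitly by this class; the 'name' key and the sampler name shown
    # here are assumptions about the models.make registry format):
    #
    #   render_sampler = {
    #       'name': 'some_registered_sampler',   # hypothetical registered name
    #       'args': {},                          # 'diffusion' is injected here by __init__
    #   }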

    def render(self, z_dec, coord, scale):
        shape = (coord.size(0), 3, coord.size(2), coord.size(3))
        net_kwargs = {'coord': coord, 'scale': scale, 'z_dec': z_dec}
        if self.use_ema_renderer:
            self.swap_ema_renderer()
        if self.renderer_guidance > 1:
            uncond_z_dec = self.drop_z_emb.unsqueeze(0).expand(z_dec.shape[0], -1, -1, -1)
            uncond_net_kwargs = {'coord': coord, 'scale': scale, 'z_dec': uncond_z_dec}
        else:
            uncond_net_kwargs = None
        ret = self.render_sampler.sample(
            net=self.renderer,
            shape=shape,
            n_steps=self.render_n_steps,
            net_kwargs=net_kwargs,
            uncond_net_kwargs=uncond_net_kwargs,
            guidance=self.renderer_guidance,
        )
        if self.use_ema_renderer:
            self.swap_ema_renderer()
        return ret
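    # When renderer_guidance > 1 the sampler is handed both conditional and
    # unconditional kwargs. A typical classifier-free-guidance combination (a sketch
    # of what the sampler is assumed to do internally; the actual formula lives in
    # render_sampler.sample) is:
    #
    #   pred = pred_uncond + guidance * (pred_cond - pred_uncond)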

    def forward(self, data, mode, has_optimizer=None):
        if mode in ['z', 'z_dec']:
            ret_z, _ = super().forward(data, mode=mode, has_optimizer=has_optimizer)
            return ret_z

        grad = self.get_grad_plan(has_optimizer)
        loss_config = self.loss_config

        if mode == 'pred':
            z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
            gt_patch = data['gt'][:, :3, ...]
            coord = data['gt'][:, 3:5, ...]
            scale = data['gt'][:, 5:7, ...]
            if grad['renderer']:
                return self.render(z_dec, coord, scale)
            else:
                with torch.no_grad():
                    return self.render(z_dec, coord, scale)

        elif mode == 'loss':
            if not grad['renderer']:  # Only training zdm
                _, ret = super().forward(data, mode='z', has_optimizer=has_optimizer)
                return ret

            # data['gt'] packs the RGB target patch with its coordinate and scale maps.
            gt_patch = data['gt'][:, :3, ...]
            coord = data['gt'][:, 3:5, ...]
            scale = data['gt'][:, 5:7, ...]
            z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
            net_kwargs = {'z_dec': z_dec}

            # Sample diffusion timesteps, optionally restricted to [gt_noise_lb, 1].
            t = torch.rand(gt_patch.shape[0], device=gt_patch.device)
            if self.gt_noise_lb is not None:
                tmin = torch.ones_like(t) * self.gt_noise_lb
                tmax = torch.ones_like(t)
                t = tmin + (tmax - tmin) * torch.rand_like(tmin)

            # For latent-augmented samples, resample t from the range selected by
            # zaug_decoding_loss_type.
            if (self.zaug_p is not None) and self.training:
                tz = self._tz
                mask_aug = self._mask_aug
                typ = self.zaug_decoding_loss_type
                if typ == 'all':
                    tmin = torch.zeros_like(tz)
                    tmax = torch.ones_like(tz)
                elif typ == 'suffix':
                    tmin = tz
                    tmax = torch.ones_like(tz)
                elif typ == 'tz':
                    tmin = tz
                    tmax = tz
                elif typ == 'tmax':
                    tmin = torch.ones_like(tz)
                    tmax = torch.ones_like(tz)
                else:
                    raise NotImplementedError
                t_aug = tmin + (tmax - tmin) * torch.rand_like(tmin)
                t = mask_aug * t_aug + (1 - mask_aug) * t

            if not self.use_lpips:
                loss, t = self.render_diffusion.loss(
                    net=self.renderer,
                    x=gt_patch,
                    t=t,
                    net_kwargs=net_kwargs,
                    return_loss_unreduced=True,
                )
            else:
                loss, t, x_t, pred = self.render_diffusion.loss(
                    net=self.renderer,
                    x=gt_patch,
                    t=t,
                    net_kwargs=net_kwargs,
                    return_loss_unreduced=True,
                    return_all=True,
                )
                # Reconstruct the sample predicted at time t and add a perceptual loss on it.
                sample_pred = x_t + t.view(-1, 1, 1, 1) * pred
                lpips_loss = self.lpips_loss(sample_pred, gt_patch).mean()
                ret['lpips_loss'] = lpips_loss.item()
                lpips_loss_w = loss_config.get('lpips_loss', 1)
                ret['loss'] = ret['loss'] + lpips_loss * lpips_loss_w

            # Visualize diffusion network loss for different timesteps #
            if self.training:
                m = len(self.t_loss_monitor_v)
                for i in range(len(loss)):
                    q = min(math.floor(t[i].item() * m), m - 1)
                    self.t_loss_monitor_v[q] = (self.t_loss_monitor_v[q] * self.t_loss_monitor_decay
                                                + loss[i].item() * (1 - self.t_loss_monitor_decay))
                    self.t_loss_monitor_n[q] += 1
                for q in range(m):
                    if self.t_loss_monitor_n[q] > 0:
                        if self.t_loss_monitor_n[q] < 500:
                            r = 1 - math.pow(self.t_loss_monitor_decay, self.t_loss_monitor_n[q])
                        else:
                            r = 1
                        ret[f'_loss_t{q}'] = self.t_loss_monitor_v[q] / r
            # - #

            dae_loss = loss.mean()
            ret['dae_loss'] = dae_loss.item()
            dae_loss_w = loss_config.get('dae_loss', 1)
            ret['loss'] = ret['loss'] + dae_loss * dae_loss_w
            return ret
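

# The per-timestep monitor above keeps an exponential moving average of the loss in
# each of 10 timestep buckets and divides by (1 - decay**n) to undo the bias of
# starting the average at zero (the same correction Adam applies to its moments).
# A minimal standalone sketch of that bookkeeping (hypothetical helper, not part of
# the original module):
def _ema_bucket_update(value, n, new_sample, decay=0.99):
    """Update one EMA bucket and return (value, n, bias-corrected estimate)."""
    value = value * decay + new_sample * (1 - decay)
    n += 1
    r = 1 - decay ** n if n < 500 else 1  # with many samples the correction is ~1
    return value, n, value / r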


class DiToAudio(LDMBaseAudio):

    def __init__(self, render_diffusion, render_sampler, render_n_steps,
                 renderer_guidance=1, **kwargs):
        super().__init__(**kwargs)
        self.render_diffusion = models.make(render_diffusion)

        if OmegaConf.is_config(render_sampler):
            render_sampler = OmegaConf.to_container(render_sampler, resolve=True)
        render_sampler = copy.deepcopy(render_sampler)
        if render_sampler.get('args') is None:
            render_sampler['args'] = {}
        render_sampler['args']['diffusion'] = self.render_diffusion
        self.render_sampler = models.make(render_sampler)

        self.render_n_steps = render_n_steps
        self.renderer_guidance = renderer_guidance

        self.t_loss_monitor_v = [0 for _ in range(10)]
        self.t_loss_monitor_n = [0 for _ in range(10)]
        self.t_loss_monitor_decay = 0.99

    def render(self, z_dec):
        net_kwargs = {'z_dec': z_dec}
        # 320 waveform samples per latent frame.
        n_frames = z_dec.size(2) * 320
        shape = (z_dec.size(0), 1, n_frames)  # assumed mono waveform: (batch, channels, samples)
        if self.renderer_guidance > 1:
            uncond_z_dec = self.drop_z_emb.unsqueeze(0).expand(z_dec.shape[0], -1, -1, -1)
            uncond_net_kwargs = {'z_dec': uncond_z_dec}
        else:
            uncond_net_kwargs = None
        ret = self.render_sampler.sample(
            net=self.renderer,
            n_steps=self.render_n_steps,
            shape=shape,
            net_kwargs=net_kwargs,
            uncond_net_kwargs=uncond_net_kwargs,
            guidance=self.renderer_guidance,
        )
        # if self.use_ema_renderer:
        #     self.swap_ema_renderer()
        return ret
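    # Duration bookkeeping (hedged: the sample rate is an assumption, not read from
    # any config here): with 320 samples per latent frame, a 16 kHz decoder would emit
    # 20 ms of audio per frame, so a z_dec of temporal length 500 renders 10 s.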

    def forward(self, data, mode, has_optimizer=None):
        if mode in ['z', 'z_dec']:
            ret_z, _ = super().forward(data, mode=mode, has_optimizer=has_optimizer)
            return ret_z

        grad = self.get_grad_plan(has_optimizer)
        loss_config = self.loss_config

        if mode == 'pred':
            z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
            gt_patch = data['gt']
            if grad['renderer']:
                return self.render(z_dec)
            else:
                with torch.no_grad():
                    return self.render(z_dec)

        elif mode == 'loss':
            if not grad['renderer']:  # Only training zdm
                _, ret = super().forward(data, mode='z', has_optimizer=has_optimizer)
                return ret

            gt_patch = data['gt']
            z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
            net_kwargs = {'z_dec': z_dec}

            # Sample diffusion timesteps; for latent-augmented samples, resample t from
            # the range selected by zaug_decoding_loss_type.
            t = torch.rand(gt_patch.shape[0], device=gt_patch.device)
            if (self.zaug_p is not None) and self.training:
                tz = self._tz
                mask_aug = self._mask_aug
                typ = self.zaug_decoding_loss_type
                if typ == 'all':
                    tmin = torch.zeros_like(tz)
                    tmax = torch.ones_like(tz)
                elif typ == 'suffix':
                    tmin = tz
                    tmax = torch.ones_like(tz)
                elif typ == 'tz':
                    tmin = tz
                    tmax = tz
                elif typ == 'tmax':
                    tmin = torch.ones_like(tz)
                    tmax = torch.ones_like(tz)
                else:
                    raise NotImplementedError
                t_aug = tmin + (tmax - tmin) * torch.rand_like(tmin)
                t = mask_aug * t_aug + (1 - mask_aug) * t

            loss, t = self.render_diffusion.loss(
                net=self.renderer,
                x=gt_patch,
                t=t,
                net_kwargs=net_kwargs,
                return_loss_unreduced=True,
            )

            # Visualize diffusion network loss for different timesteps #
            if self.training:
                m = len(self.t_loss_monitor_v)
                for i in range(len(loss)):
                    q = min(math.floor(t[i].item() * m), m - 1)
                    self.t_loss_monitor_v[q] = (self.t_loss_monitor_v[q] * self.t_loss_monitor_decay
                                                + loss[i].item() * (1 - self.t_loss_monitor_decay))
                    self.t_loss_monitor_n[q] += 1
                for q in range(m):
                    if self.t_loss_monitor_n[q] > 0:
                        if self.t_loss_monitor_n[q] < 500:
                            r = 1 - math.pow(self.t_loss_monitor_decay, self.t_loss_monitor_n[q])
                        else:
                            r = 1
                        ret[f'_loss_t{q}'] = self.t_loss_monitor_v[q] / r
            # - #

            dae_loss = loss.mean()
            ret['dae_loss'] = dae_loss.item()
            dae_loss_w = loss_config.get('dae_loss', 1)
            ret['loss'] = ret['loss'] + dae_loss * dae_loss_w
            return ret
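

# Illustrative usage (a sketch under assumptions: the registered model names, the
# structure of `data`, and the has_optimizer plan are hypothetical; the real values
# come from this repo's configs):
#
#   model_spec = OmegaConf.load('configs/dito.yaml')        # hypothetical path
#   model = models.make(OmegaConf.to_container(model_spec, resolve=True))
#   ret = model(data, mode='loss', has_optimizer={'renderer': True})
#   ret['loss'].backward()
#   recon = model(data, mode='pred', has_optimizer=None)    # decoded patches / waveform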