Spaces:
Sleeping
Sleeping
| import os | |
| import random | |
| import torch | |
| import torch.distributed as dist | |
| import torch_fidelity | |
| import torchvision | |
| from PIL import Image | |
| from torchvision import transforms | |
| import utils | |
| from utils.geometry import make_coord_scale_grid | |
| from .trainers import register | |
| from trainers.base_trainer import BaseTrainer | |
| from models.ldm.dac.audiotools import AudioSignal | |
| import soundfile as sf | |
| import numpy as np | |
| from models.ldm.dac.loss import (GANLoss, L1Loss, MelSpectrogramLoss, | |
| MultiScaleSTFTLoss, kl_loss) | |
| class LDMTrainer(BaseTrainer): | |
def make_model(self):
    """Build the model via the base trainer and log the size of each child.

    After the parent class has created the model, ``self.has_optimizer`` is
    reset to an empty dict and one log line is written per direct child
    module of ``self.model``, containing the child's name and the value
    returned by ``utils.compute_num_params`` for it.
    """
    super().make_model()
    self.has_optimizer = {}
    for child_name, child in self.model.named_children():
        size = utils.compute_num_params(child)
        self.log(f' .{child_name} {size}')
def make_optimizers(self):
    """Create one optimizer per entry of ``self.config.optimizers``.

    For every ``name -> spec`` pair an optimizer is built with
    ``utils.make_optimizer`` over ``self.model.get_parameters(name)`` and
    stored in ``self.optimizers[name]``; ``self.has_optimizer[name]`` is set
    to ``True`` for the same name. Both dicts are recreated from scratch.
    """
    self.optimizers = {}
    self.has_optimizer = {}
    for group, opt_spec in self.config.optimizers.items():
        params = self.model.get_parameters(group)
        self.optimizers[group] = utils.make_optimizer(params, opt_spec)
        self.has_optimizer[group] = True
def train_step(self, data, bp=True):
    """Run one forward pass in loss mode and, optionally, one update step.

    Args:
        data: Batch passed unchanged to ``self.model_ddp``.
        bp: When true, gradients are zeroed, the loss is back-propagated,
            every optimizer except the one named ``'disc'`` is stepped and
            the model's EMA weights are updated. When false (as used by
            ``evaluate``) only the forward pass runs.

    Returns:
        The dict returned by the model, with its ``'loss'`` entry replaced
        by the Python scalar ``loss.item()``.
    """
    kwargs = {'has_optimizer': self.has_optimizer}

    # The per-step debug prints of the batch keys/shapes were removed: they
    # flooded stdout on every rank and raised KeyError for batches that do
    # not carry both 'inp' and 'gt'.
    if self.config.get('autocast_bfloat16', False):
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            ret = self.model_ddp(data, mode='loss', **kwargs)
    else:
        ret = self.model_ddp(data, mode='loss', **kwargs)

    loss = ret.pop('loss')
    ret['loss'] = loss.item()

    if bp:
        self.model_ddp.zero_grad()
        loss.backward()
        for name, optimizer in self.optimizers.items():
            # The 'disc' optimizer is deliberately not stepped here.
            if name != 'disc':
                optimizer.step()
        # NOTE(review): the flattened source does not show whether the EMA
        # update sat inside this branch. It is kept here so that forward-only
        # calls (bp=False) never modify the EMA weights -- confirm upstream.
        self.model.update_ema()

    return ret
def evaluate(self):
    """Run the validation loop and log averaged scalars.

    Every batch of ``self.loaders['val']`` is moved to ``self.device``
    (tensor entries only) and passed through ``train_step`` with
    ``bp=False``. Each returned scalar is accumulated in a
    ``utils.Averager`` weighted by the batch size, taken as the length of
    the first value in the batch dict. The averagers are then handed to
    ``self.sync_ave_scalars``. If enabled in the config, the results of
    ``evaluate_ae`` / ``evaluate_zdm`` are merged in. Finally each value is
    logged under ``val/<key>`` and a summary line is appended to
    ``self.log_buffer``.

    Returns:
        Dict mapping metric name to its averager object.
    """
    self.model_ddp.eval()

    meters = {}
    for batch in self.loaders['val']:
        for key, value in batch.items():
            if torch.is_tensor(value):
                batch[key] = value.to(self.device)
        scalars = self.train_step(batch, bp=False)
        batch_size = len(next(iter(batch.values())))
        for key, value in scalars.items():
            meter = meters.get(key)
            if meter is None:
                meter = utils.Averager()
                meters[key] = meter
            meter.add(value, n=batch_size)
    self.sync_ave_scalars(meters)

    # Extra evaluation #
    if self.config.get('evaluate_ae', False):
        meters.update(self.evaluate_ae())
    if self.config.get('evaluate_zdm', False):
        use_ema = self.config.get('evaluate_zdm_ema', True)
        meters.update(self.evaluate_zdm(ema=use_ema))
    # - #

    pieces = ['val:']
    for key, meter in meters.items():
        pieces.append(f' {key}={meter.item():.4f}')
        self.log_scalar('val/' + key, meter.item())
    self.log_buffer.append(''.join(pieces))
    return meters
def visualize(self):
    """Dispatch the visualisations enabled in the config.

    Puts ``self.model_ddp`` in eval mode, then runs the random-sample
    autoencoder visualisation when ``evaluate_ae`` is set and the
    random-sample ZDM visualisation when ``evaluate_zdm`` is set (with the
    ``evaluate_zdm_ema`` flag, default ``True``). The fixed-set and
    denoising variants exist on the class but are not invoked here.
    """
    self.model_ddp.eval()
    cfg = self.config

    if cfg.get('evaluate_ae', False):
        self.visualize_ae_random()

    if cfg.get('evaluate_zdm', False):
        use_ema = cfg.get('evaluate_zdm_ema', True)
        self.visualize_zdm_random(ema=use_ema)
def evaluate_ae(self):
    """Evaluate autoencoder reconstructions with PSNR and FID.

    Every rank iterates ``self.loaders['eval_ae']``, obtains
    ``self.model(data, mode='pred')`` and compares it with the first three
    channels of ``data['gt']``. Both are mapped through ``x * 0.5 + 0.5``
    and clamped to ``[0, 1]`` (presumably from a ``[-1, 1]`` range -- TODO
    confirm). Images are written as PNG files into
    ``<save_dir>/cache/fid_gen`` and ``<save_dir>/cache/fid_gt`` using a
    file index interleaved across ranks; files whose index reaches
    ``eval_ae_max_samples`` (when configured) are skipped.

    The per-rank PSNR average is summed over ranks with ``all_reduce`` and
    divided by the world size. FID is computed on the master rank only.

    Returns:
        On the master rank, a dict with ``eval_ae/PSNR`` and ``eval_ae/FID``
        wrapped in ``utils.Averager``; an empty dict on other ranks.
    """
    max_samples = self.config.get('eval_ae_max_samples')
    self.loader_samplers['eval_ae'].set_epoch(0)
    to_pil = transforms.ToPILImage()
    psnr_meter = utils.Averager()

    # Process-group coordinates do not change during the loop; parse once.
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    cache_gen_dir = os.path.join(self.env['save_dir'], 'cache', 'fid_gen')
    cache_gt_dir = os.path.join(self.env['save_dir'], 'cache', 'fid_gt')
    if self.is_master:
        utils.ensure_path(cache_gen_dir, force_replace=True)
        utils.ensure_path(cache_gt_dir, force_replace=True)
    # Other ranks must not write before the master has reset the cache dirs.
    dist.barrier()

    cnt = 0
    for data in self.loaders['eval_ae']:
        for k, v in data.items():
            data[k] = v.to(self.device) if torch.is_tensor(v) else v

        pred = self.model(data, mode='pred')
        gt_patch = data['gt'][:, :3, ...]
        pred = (pred * 0.5 + 0.5).clamp(0, 1)
        gt_patch = (gt_patch * 0.5 + 0.5).clamp(0, 1)

        # PSNR: weight the batch mean by the batch size so that a short
        # final batch does not count as much as a full one.
        mse = (pred - gt_patch).pow(2).mean(dim=[1, 2, 3])
        psnr_meter.add((-10 * torch.log10(mse)).mean().item(), n=len(pred))

        # FID: dump generated / reference images for torch_fidelity.
        for i in range(len(pred)):
            # Interleave file indices so ranks never collide on a name.
            idx = rank + cnt * world_size
            cnt += 1
            if max_samples is None or idx < max_samples:
                to_pil(pred[i]).save(os.path.join(cache_gen_dir, f'{idx}.png'))
                to_pil(gt_patch[i]).save(os.path.join(cache_gt_dir, f'{idx}.png'))
    # All files must be on disk before the master reads the directories.
    dist.barrier()

    vt = torch.tensor(psnr_meter.item(), device=self.device)
    dist.all_reduce(vt, op=dist.ReduceOp.SUM)
    torch.cuda.synchronize()
    psnr_value = vt.item() / world_size

    if self.is_master:
        metrics = torch_fidelity.calculate_metrics(
            input1=cache_gen_dir,
            input2=cache_gt_dir,
            cuda=True,
            fid=True,
            verbose=False,
        )
        prefix = 'eval_ae'
        ret = {
            f'{prefix}/PSNR': psnr_value,
            f'{prefix}/FID': metrics['frechet_inception_distance'],
        }
    else:
        ret = {}
    dist.barrier()

    return {k: utils.Averager(v) for k, v in ret.items()}
def evaluate_zdm(self, ema):
    """Compute FID between generated samples and the ``eval_zdm`` inputs.

    For every batch of ``self.loaders['eval_zdm']`` a matching number of
    samples is drawn with ``self.model.generate_samples``. When
    ``self.model.zdm_class_cond`` is not ``None`` the batch's
    ``class_labels`` are passed as conditioning, and a label tensor filled
    with ``n_classes`` (from ``visualize_zdm_setting``) is passed as the
    unconditional counterpart. Generated images and ``data['inp']`` are
    mapped through ``x * 0.5 + 0.5``, clamped to ``[0, 1]`` and saved as
    PNG files under ``<save_dir>/cache/fid_gen`` / ``fid_gt``; indices at
    or beyond ``eval_zdm_max_samples`` (when configured) are skipped.

    Args:
        ema: Forwarded to ``generate_samples``; also selects the ``_ema``
            suffix of the metric name.

    Returns:
        On the master rank, ``{'eval_zdm[_ema]/FID': Averager}``; an empty
        dict on other ranks.
    """
    limit = self.config.get('eval_zdm_max_samples')
    self.loader_samplers['eval_zdm'].set_epoch(0)
    as_pil = transforms.ToPILImage()

    gen_dir = os.path.join(self.env['save_dir'], 'cache', 'fid_gen')
    ref_dir = os.path.join(self.env['save_dir'], 'cache', 'fid_gt')
    if self.is_master:
        utils.ensure_path(gen_dir, force_replace=True)
        utils.ensure_path(ref_dir, force_replace=True)
    dist.barrier()

    seen = 0
    for batch in self.loaders['eval_zdm']:
        for key, value in batch.items():
            if torch.is_tensor(value):
                batch[key] = value.to(self.device)
        reference = batch['inp']

        cond_kwargs = {}
        uncond_kwargs = {}
        if self.model.zdm_class_cond is not None:
            cond_kwargs['class_labels'] = batch['class_labels']
            setting = self.config['visualize_zdm_setting']
            uncond_kwargs['class_labels'] = setting['n_classes'] * torch.ones(
                len(batch['class_labels']), dtype=torch.long, device=self.device)

        generated = self.model.generate_samples(
            batch_size=reference.shape[0],
            n_steps=self.model.zdm_n_steps,
            net_kwargs=cond_kwargs,
            uncond_net_kwargs=uncond_kwargs,
            ema=ema
        )
        generated = (generated * 0.5 + 0.5).clamp(0, 1)
        reference = (reference * 0.5 + 0.5).clamp(0, 1)

        # FID
        for i in range(len(generated)):
            file_idx = int(os.environ['RANK']) + seen * int(os.environ['WORLD_SIZE'])
            seen += 1
            if limit is not None and file_idx >= limit:
                continue
            as_pil(generated[i]).save(os.path.join(gen_dir, f'{file_idx}.png'))
            as_pil(reference[i]).save(os.path.join(ref_dir, f'{file_idx}.png'))
    dist.barrier()

    results = {}
    if self.is_master:
        metrics = torch_fidelity.calculate_metrics(
            input1=gen_dir,
            input2=ref_dir,
            cuda=True,
            fid=True,
            verbose=False,
        )
        prefix = 'eval_zdm' + ('_ema' if ema else '')
        results[f'{prefix}/FID'] = metrics['frechet_inception_distance']
    dist.barrier()

    return {name: utils.Averager(value) for name, value in results.items()}
def visualize_ae_fixset(self):
    """Log reconstructions of a fixed set of images from a directory.

    Does nothing when ``visualize_ae_dir`` is not configured. Otherwise the
    master rank loads every file of that directory (sorted by name) as RGB,
    rescales it with ``(x - 0.5) / 0.5``, runs the model twice in ``'pred'``
    mode and logs a grid of ``[input, pred1, pred2]`` triples under
    ``'vis_ae_fixset'``. All ranks meet at a barrier afterwards.
    """
    vis_dir = self.config.get('visualize_ae_dir')
    if vis_dir is None:
        return

    to_tensor = transforms.ToTensor()
    if self.is_master:
        vis_images = []
        for fname in sorted(os.listdir(vis_dir)):
            # Context manager releases the file handle deterministically.
            with Image.open(os.path.join(vis_dir, fname)) as image:
                x = to_tensor(image.convert('RGB')).unsqueeze(0).to(self.device)
            x = (x - 0.5) / 0.5
            # Placeholder 'gt' entry with 7 channels at the input's spatial
            # size; assumes the model only needs its shape -- TODO confirm.
            gt_dummy = torch.zeros(x.shape[0], 7, x.shape[2], x.shape[3], device=self.device)
            # Two passes on the same input, shown side by side (presumably
            # to expose stochasticity in the model -- TODO confirm).
            pred1 = self.model({'inp': x, 'gt': gt_dummy}, mode='pred')
            pred2 = self.model({'inp': x, 'gt': gt_dummy}, mode='pred')
            vis_images.extend([x, pred1, pred2])

        # An empty directory would make torch.cat raise on the master only,
        # leaving the other ranks blocked at the barrier below.
        if vis_images:
            vis_images = torch.cat(vis_images, dim=0)
            vis_images = torchvision.utils.make_grid(
                vis_images, normalize=True, value_range=(-1, 1), nrow=6)
            self.log_image('vis_ae_fixset', vis_images)
    dist.barrier()
def visualize_ae_random(self):
    """Log reconstructions of randomly chosen ``eval_ae`` samples.

    On the master rank the dataset indices are shuffled with
    ``random.shuffle`` and the first ``visualize_ae_random_n_samples`` are
    used. Each sample's tensor entries get a leading batch axis and are
    moved to ``self.device``; the model is run twice in ``'pred'`` mode and
    the first three channels of ``'gt'`` plus both predictions are added to
    a grid logged as ``'vis_ae_random'``. All ranks then hit a barrier.
    """
    if self.is_master:
        dataset = self.datasets['eval_ae']
        order = list(range(len(dataset)))
        random.shuffle(order)
        limit = self.config['visualize_ae_random_n_samples']

        panels = []
        for sample_idx in order[:limit]:
            sample = dataset[sample_idx]
            for key, value in sample.items():
                if torch.is_tensor(value):
                    sample[key] = value.unsqueeze(0).to(self.device)
            recon_a = self.model(sample, mode='pred')
            recon_b = self.model(sample, mode='pred')
            target = sample['gt'][:, :3, ...]
            panels += [target, recon_a, recon_b]

        stacked = torch.cat(panels, dim=0)
        grid = torchvision.utils.make_grid(stacked, normalize=True, value_range=(-1, 1), nrow=6)
        self.log_image('vis_ae_random', grid)
    dist.barrier()
def visualize_zdm_fixset(self, ema):
    """Log ZDM samples generated from a stored set of noise tensors.

    The master rank loads ``visualize_zdm_file`` with ``torch.load``, moves
    its tensor entries to ``self.device`` and walks over
    ``vis_file['noise']`` in chunks of ``visualize_zdm_batch_size``
    (default 1). For each chunk, samples are generated once per guidance
    value: ``1.0`` followed by the entries of
    ``visualize_zdm_guidance_list``. If ``visualize_zdm_setting`` is
    configured it must have ``name == 'class'``; the stored class labels
    are then used for conditioning and a tensor filled with ``n_classes``
    for the unconditional branch.

    Args:
        ema: Forwarded to ``generate_samples``; adds ``_ema`` to the log name.

    Raises:
        NotImplementedError: If the configured setting name is not ``'class'``.
    """
    if self.is_master:
        stored = torch.load(self.config['visualize_zdm_file'], map_location='cpu')
        for key, value in stored.items():
            if torch.is_tensor(value):
                stored[key] = value.to(self.device)

        total = len(stored['noise'])
        batch_size = self.config.get('visualize_zdm_batch_size', 1)
        guidance_list = [1.0] + self.config.get('visualize_zdm_guidance_list', [])

        samples = []
        for start in range(0, total, batch_size):
            chunk = min(batch_size, total - start)
            stop = start + chunk

            cond_kwargs = {}
            uncond_kwargs = {}
            if self.config.get('visualize_zdm_setting') is not None:
                setting = self.config['visualize_zdm_setting']
                if setting['name'] != 'class':
                    raise NotImplementedError
                cond_kwargs['class_labels'] = stored['class_labels'][start:stop]
                uncond_kwargs['class_labels'] = setting['n_classes'] * torch.ones(
                    chunk, dtype=torch.long, device=self.device)

            for scale in guidance_list:
                samples.append(self.model.generate_samples(
                    batch_size=chunk,
                    n_steps=self.model.zdm_n_steps,
                    net_kwargs=cond_kwargs,
                    uncond_net_kwargs=uncond_kwargs,
                    ema=ema,
                    guidance=scale,
                    noise=stored['noise'][start:stop],
                ))

        grid = torchvision.utils.make_grid(
            torch.cat(samples, dim=0), normalize=True, value_range=(-1, 1), nrow=batch_size)

        tag = 'vis_zdm_fixset'
        if ema:
            tag += '_ema'
        if len(guidance_list) > 1:
            # str([...])[1:-1] strips the list brackets from the repr.
            tag += '_cfg' + str(guidance_list[1:])[1:-1]
        self.log_image(tag, grid)
    dist.barrier()
def visualize_zdm_random(self, ema):
    """Log ZDM samples generated from fresh noise.

    The master rank generates ``visualize_zdm_random_n_samples`` samples in
    chunks of ``visualize_zdm_batch_size`` (default 1), once per guidance
    value: ``1.0`` followed by the entries of
    ``visualize_zdm_guidance_list``. If ``visualize_zdm_setting`` is
    configured it must have ``name == 'class'``; class labels are then drawn
    with ``torch.randint`` below ``n_classes`` and the unconditional branch
    receives a tensor filled with ``n_classes``.

    Args:
        ema: Forwarded to ``generate_samples``; adds ``_ema`` to the log name.

    Raises:
        NotImplementedError: If the configured setting name is not ``'class'``.
    """
    total = self.config['visualize_zdm_random_n_samples']
    batch_size = self.config.get('visualize_zdm_batch_size', 1)
    guidance_list = [1.0] + self.config.get('visualize_zdm_guidance_list', [])
    samples = []

    if self.is_master:
        for start in range(0, total, batch_size):
            chunk = min(batch_size, total - start)

            cond_kwargs = {}
            uncond_kwargs = {}
            if self.config.get('visualize_zdm_setting') is not None:
                setting = self.config['visualize_zdm_setting']
                if setting['name'] != 'class':
                    raise NotImplementedError
                cond_kwargs['class_labels'] = torch.randint(
                    setting['n_classes'], size=(chunk,), device=self.device)
                uncond_kwargs['class_labels'] = setting['n_classes'] * torch.ones(
                    chunk, dtype=torch.long, device=self.device)

            for scale in guidance_list:
                samples.append(self.model.generate_samples(
                    batch_size=chunk,
                    n_steps=self.model.zdm_n_steps,
                    net_kwargs=cond_kwargs,
                    uncond_net_kwargs=uncond_kwargs,
                    ema=ema,
                    guidance=scale,
                ))

        grid = torchvision.utils.make_grid(
            torch.cat(samples, dim=0), normalize=True, value_range=(-1, 1), nrow=batch_size)

        tag = 'vis_zdm_random'
        if ema:
            tag += '_ema'
        if len(guidance_list) > 1:
            # str([...])[1:-1] strips the list brackets from the repr.
            tag += '_cfg' + str(guidance_list[1:])[1:-1]
        self.log_image(tag, grid)
    dist.barrier()
def visualize_zdm_denoising(self, ema, n_selected_timesteps=5):
    """Log noisy and one-step-denoised renderings at several noise levels.

    On the master rank, for each entry of ``visualize_zdm_denoising_file``:
    the input is repeated ``n_selected_timesteps`` times, encoded and
    normalised, and noised at the timesteps
    ``linspace(0, 1, n_selected_timesteps + 1)[1:]`` using the stored noise.
    One grid row shows the rendered noisy latents; the next shows the
    rendered latents after converting the network prediction back to a
    sample with ``convert_sample_prediction``.

    Args:
        ema: Selects ``self.model.zdm_net_ema`` over ``self.model.zdm_net``
            and adds ``_ema`` to the log name.
        n_selected_timesteps: Number of noise levels, and grid row width.

    Raises:
        NotImplementedError: If ``visualize_zdm_setting`` is configured with
            a name other than ``'class'``.
    """
    def render_latents(latents, reference):
        # Latent -> image: undo the ZDM normalisation, decode, then render
        # on a coordinate grid matching the reference's spatial size.
        decoded = self.model.decode(self.model.denormalize_for_zdm(latents))
        coord, scale = make_coord_scale_grid(
            reference.shape[-2:], device=self.device, batch_size=n_selected_timesteps)
        coord = coord.permute(0, 3, 1, 2)
        scale = scale.permute(0, 3, 1, 2)
        return self.model.render(decoded, coord, scale)

    if self.is_master:
        stored = torch.load(self.config['visualize_zdm_denoising_file'], map_location='cpu')
        rows = []
        for i in range(len(stored['inp'])):
            x = stored['inp'][i].to(self.device).unsqueeze(0)
            x = x.expand(n_selected_timesteps, -1, -1, -1)

            z = self.model.normalize_for_zdm(self.model.encode(x))
            t = torch.linspace(0, 1, n_selected_timesteps + 1, device=self.device)[1:]
            noise = stored['noise'][i].to(self.device).unsqueeze(0)
            noise = noise.expand(n_selected_timesteps, -1, -1, -1)
            z_t, _ = self.model.zdm_diffusion.add_noise(z, t, noise=noise)

            # Visualize noisy latents
            rows.append(render_latents(z_t, x))

            # Generate denoised latents
            net = self.model.zdm_net_ema if ema else self.model.zdm_net
            net_kwargs = {}
            if self.config.get('visualize_zdm_setting') is not None:
                setting = self.config['visualize_zdm_setting']
                if setting['name'] != 'class':
                    raise NotImplementedError
                labels = stored['class_labels'][i].to(self.device).unsqueeze(0)
                net_kwargs['class_labels'] = labels.expand(n_selected_timesteps)

            pred = self.model.zdm_diffusion.get_prediction(net, z_t, t, net_kwargs=net_kwargs)
            denoised = torch.stack(
                [
                    self.model.zdm_diffusion.convert_sample_prediction(z_t[j], float(t[j]), pred[j])
                    for j in range(len(pred))
                ],
                dim=0,
            )

            # Visualize denoised latents
            rows.append(render_latents(denoised, x))

        grid = torchvision.utils.make_grid(
            torch.cat(rows, dim=0), normalize=True, value_range=(-1, 1), nrow=n_selected_timesteps)
        self.log_image('vis_zdm' + ('_ema' if ema else '') + '_denoising', grid)
    dist.barrier()