primepake
add training flowvae
4f877a2
raw
history blame
18 kB
import os
import random
import torch
import torch.distributed as dist
import torch_fidelity
import torchvision
from PIL import Image
from torchvision import transforms
import utils
from utils.geometry import make_coord_scale_grid
from .trainers import register
from trainers.base_trainer import BaseTrainer
from models.ldm.dac.audiotools import AudioSignal
import soundfile as sf
import numpy as np
from models.ldm.dac.loss import (GANLoss, L1Loss, MelSpectrogramLoss,
MultiScaleSTFTLoss, kl_loss)
@register('ldm_trainer')
class LDMTrainer(BaseTrainer):
def make_model(self):
super().make_model()
self.has_optimizer = dict()
for name, m in self.model.named_children():
self.log(f' .{name} {utils.compute_num_params(m)}')
def make_optimizers(self):
self.optimizers = dict()
self.has_optimizer = dict()
for name, spec in self.config.optimizers.items():
self.optimizers[name] = utils.make_optimizer(self.model.get_parameters(name), spec)
self.has_optimizer[name] = True
def train_step(self, data, bp=True):
kwargs = {'has_optimizer': self.has_optimizer}
print('data', data.keys())
print('inp', data['inp'].shape)
print('gt', data['gt'].shape)
if self.config.get('autocast_bfloat16', False):
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
ret = self.model_ddp(data, mode='loss', **kwargs)
else:
ret = self.model_ddp(data, mode='loss', **kwargs)
loss = ret.pop('loss')
ret['loss'] = loss.item()
if bp:
self.model_ddp.zero_grad()
loss.backward()
for name, o in self.optimizers.items():
if name != 'disc':
o.step()
self.model.update_ema()
return ret
def evaluate(self):
self.model_ddp.eval()
ave_scalars = dict()
pbar = self.loaders['val']
for data in pbar:
for k, v in data.items():
data[k] = v.to(self.device) if torch.is_tensor(v) else v
ret = self.train_step(data, bp=False)
bs = len(next(iter(data.values())))
for k, v in ret.items():
if ave_scalars.get(k) is None:
ave_scalars[k] = utils.Averager()
ave_scalars[k].add(v, n=bs)
self.sync_ave_scalars(ave_scalars)
# Extra evaluation #
if self.config.get('evaluate_ae', False):
ave_scalars.update(self.evaluate_ae())
if self.config.get('evaluate_zdm', False):
ema = self.config.get('evaluate_zdm_ema', True)
ave_scalars.update(self.evaluate_zdm(ema=ema))
# - #
logtext = 'val:'
for k, v in ave_scalars.items():
logtext += f' {k}={v.item():.4f}'
self.log_scalar('val/' + k, v.item())
self.log_buffer.append(logtext)
return ave_scalars
def visualize(self):
self.model_ddp.eval()
if self.config.get('evaluate_ae', False):
# self.visualize_ae_fixset()
self.visualize_ae_random()
if self.config.get('evaluate_zdm', False):
ema = self.config.get('evaluate_zdm_ema', True)
# self.visualize_zdm_fixset(ema=ema)
self.visualize_zdm_random(ema=ema)
# self.visualize_zdm_denoising(ema=ema)
def evaluate_ae(self):
max_samples = self.config.get('eval_ae_max_samples')
self.loader_samplers['eval_ae'].set_epoch(0)
to_pil = transforms.ToPILImage()
psnr_value = utils.Averager()
cnt = 0
cache_gen_dir = os.path.join(self.env['save_dir'], 'cache', 'fid_gen')
cache_gt_dir = os.path.join(self.env['save_dir'], 'cache', 'fid_gt')
if self.is_master:
utils.ensure_path(cache_gen_dir, force_replace=True)
utils.ensure_path(cache_gt_dir, force_replace=True)
dist.barrier()
for data in self.loaders['eval_ae']:
for k, v in data.items():
data[k] = v.to(self.device) if torch.is_tensor(v) else v
pred = self.model(data, mode='pred')
gt_patch = data['gt'][:, :3, ...]
pred = (pred * 0.5 + 0.5).clamp(0, 1)
gt_patch = (gt_patch * 0.5 + 0.5).clamp(0, 1)
# PSNR
mse = (pred - gt_patch).pow(2).mean(dim=[1, 2, 3])
psnr_value.add((-10 * torch.log10(mse)).mean().item())
# FID
for i in range(len(pred)):
idx = int(os.environ['RANK']) + cnt * int(os.environ['WORLD_SIZE'])
if max_samples is None or idx < max_samples:
to_pil(pred[i]).save(os.path.join(cache_gen_dir, f'{idx}.png'))
to_pil(gt_patch[i]).save(os.path.join(cache_gt_dir, f'{idx}.png'))
cnt += 1
dist.barrier()
vt = torch.tensor(psnr_value.item(), device=self.device)
dist.all_reduce(vt, op=dist.ReduceOp.SUM)
torch.cuda.synchronize()
psnr_value = vt.item() / int(os.environ['WORLD_SIZE'])
if self.is_master:
metrics = torch_fidelity.calculate_metrics(
input1=cache_gen_dir,
input2=cache_gt_dir,
cuda=True,
fid=True,
verbose=False,
)
prefix = 'eval_ae'
ret = {
f'{prefix}/PSNR': psnr_value,
f'{prefix}/FID': metrics['frechet_inception_distance'],
}
else:
ret = {}
dist.barrier()
ret = {k: utils.Averager(v) for k, v in ret.items()}
return ret
def evaluate_zdm(self, ema):
max_samples = self.config.get('eval_zdm_max_samples')
self.loader_samplers['eval_zdm'].set_epoch(0)
to_pil = transforms.ToPILImage()
cnt = 0
cache_gen_dir = os.path.join(self.env['save_dir'], 'cache', 'fid_gen')
cache_gt_dir = os.path.join(self.env['save_dir'], 'cache', 'fid_gt')
if self.is_master:
utils.ensure_path(cache_gen_dir, force_replace=True)
utils.ensure_path(cache_gt_dir, force_replace=True)
dist.barrier()
for data in self.loaders['eval_zdm']:
for k, v in data.items():
data[k] = v.to(self.device) if torch.is_tensor(v) else v
gt_patch = data['inp']
net_kwargs = dict()
uncond_net_kwargs = dict()
if self.model.zdm_class_cond is not None:
net_kwargs['class_labels'] = data['class_labels']
setting = self.config['visualize_zdm_setting']
uncond_net_kwargs['class_labels'] = setting['n_classes'] * torch.ones(
len(data['class_labels']), dtype=torch.long, device=self.device)
pred = self.model.generate_samples(
batch_size=gt_patch.shape[0],
n_steps=self.model.zdm_n_steps,
net_kwargs=net_kwargs,
uncond_net_kwargs=uncond_net_kwargs,
ema=ema
)
pred = (pred * 0.5 + 0.5).clamp(0, 1)
gt_patch = (gt_patch * 0.5 + 0.5).clamp(0, 1)
# FID
for i in range(len(pred)):
idx = int(os.environ['RANK']) + cnt * int(os.environ['WORLD_SIZE'])
if max_samples is None or idx < max_samples:
to_pil(pred[i]).save(os.path.join(cache_gen_dir, f'{idx}.png'))
to_pil(gt_patch[i]).save(os.path.join(cache_gt_dir, f'{idx}.png'))
cnt += 1
dist.barrier()
if self.is_master:
metrics = torch_fidelity.calculate_metrics(
input1=cache_gen_dir,
input2=cache_gt_dir,
cuda=True,
fid=True,
verbose=False,
)
prefix = 'eval_zdm' + ('_ema' if ema else '')
ret = {
f'{prefix}/FID': metrics['frechet_inception_distance'],
}
else:
ret = {}
dist.barrier()
ret = {k: utils.Averager(v) for k, v in ret.items()}
return ret
def visualize_ae_fixset(self):
if self.config.get('visualize_ae_dir') is None:
return
to_tensor = transforms.ToTensor()
if self.is_master:
files = sorted(os.listdir(self.config['visualize_ae_dir']))
vis_images = []
for f in files:
image = Image.open(os.path.join(self.config['visualize_ae_dir'], f)).convert('RGB')
x = to_tensor(image).unsqueeze(0).to(self.device)
x = (x - 0.5) / 0.5
gt_dummy = torch.zeros(x.shape[0], 7, x.shape[2], x.shape[3], device=self.device)
pred1 = self.model({'inp': x, 'gt': gt_dummy}, mode='pred')
pred2 = self.model({'inp': x, 'gt': gt_dummy}, mode='pred')
vis_images.extend([x, pred1, pred2])
vis_images = torch.cat(vis_images, dim=0)
vis_images = torchvision.utils.make_grid(vis_images, normalize=True, value_range=(-1, 1), nrow=6)
self.log_image('vis_ae_fixset', vis_images)
dist.barrier()
def visualize_ae_random(self):
if self.is_master:
idx_list = list(range(len(self.datasets['eval_ae'])))
random.shuffle(idx_list)
n_samples = self.config['visualize_ae_random_n_samples']
vis_images = []
for idx in idx_list[:n_samples]:
data = self.datasets['eval_ae'][idx]
for k, v in data.items():
data[k] = v.unsqueeze(0).to(self.device) if torch.is_tensor(v) else v
pred1 = self.model(data, mode='pred')
pred2 = self.model(data, mode='pred')
gt_patch = data['gt'][:, :3, ...]
vis_images.extend([gt_patch, pred1, pred2])
vis_images = torch.cat(vis_images, dim=0)
vis_images = torchvision.utils.make_grid(vis_images, normalize=True, value_range=(-1, 1), nrow=6)
self.log_image('vis_ae_random', vis_images)
dist.barrier()
def visualize_zdm_fixset(self, ema):
if self.is_master:
vis_file = torch.load(self.config['visualize_zdm_file'], map_location='cpu')
for k, v in vis_file.items():
vis_file[k] = v.to(self.device) if torch.is_tensor(v) else v
n_samples = len(vis_file['noise'])
batch_size = self.config.get('visualize_zdm_batch_size', 1)
guidance_list = [1.0] + self.config.get('visualize_zdm_guidance_list', [])
vis_images = []
for i in range(0, n_samples, batch_size):
cur_batch_size = min(batch_size, n_samples - i)
net_kwargs = dict()
uncond_net_kwargs = dict()
if self.config.get('visualize_zdm_setting') is not None:
setting = self.config['visualize_zdm_setting']
if setting['name'] == 'class':
net_kwargs['class_labels'] = vis_file['class_labels'][i:i + cur_batch_size]
uncond_net_kwargs['class_labels'] = setting['n_classes'] * torch.ones(
cur_batch_size, dtype=torch.long, device=self.device)
else:
raise NotImplementedError
for guidance in guidance_list:
pred = self.model.generate_samples(
batch_size=cur_batch_size,
n_steps=self.model.zdm_n_steps,
net_kwargs=net_kwargs,
uncond_net_kwargs=uncond_net_kwargs,
ema=ema,
guidance=guidance,
noise=vis_file['noise'][i:i + cur_batch_size],
)
vis_images.append(pred)
vis_images = torch.cat(vis_images, dim=0)
vis_images = torchvision.utils.make_grid(vis_images, normalize=True, value_range=(-1, 1), nrow=batch_size)
name = 'vis_zdm_fixset'
name += '_ema' if ema else ''
name += '_cfg' + str(guidance_list[1:])[1:-1] if len(guidance_list) > 1 else ''
self.log_image(name, vis_images)
dist.barrier()
def visualize_zdm_random(self, ema):
n_samples = self.config['visualize_zdm_random_n_samples']
batch_size = self.config.get('visualize_zdm_batch_size', 1)
guidance_list = [1.0] + self.config.get('visualize_zdm_guidance_list', [])
vis_images = []
if self.is_master:
for i in range(0, n_samples, batch_size):
cur_batch_size = min(batch_size, n_samples - i)
net_kwargs = dict()
uncond_net_kwargs = dict()
if self.config.get('visualize_zdm_setting') is not None:
setting = self.config['visualize_zdm_setting']
if setting['name'] == 'class':
net_kwargs['class_labels'] = torch.randint(
setting['n_classes'], size=(cur_batch_size,), device=self.device)
uncond_net_kwargs['class_labels'] = setting['n_classes'] * torch.ones(
cur_batch_size, dtype=torch.long, device=self.device)
else:
raise NotImplementedError
for guidance in guidance_list:
pred = self.model.generate_samples(
batch_size=cur_batch_size,
n_steps=self.model.zdm_n_steps,
net_kwargs=net_kwargs,
uncond_net_kwargs=uncond_net_kwargs,
ema=ema,
guidance=guidance,
)
vis_images.append(pred)
vis_images = torch.cat(vis_images, dim=0)
vis_images = torchvision.utils.make_grid(vis_images, normalize=True, value_range=(-1, 1), nrow=batch_size)
name = 'vis_zdm_random'
name += '_ema' if ema else ''
name += '_cfg' + str(guidance_list[1:])[1:-1] if len(guidance_list) > 1 else ''
self.log_image(name, vis_images)
dist.barrier()
def visualize_zdm_denoising(self, ema, n_selected_timesteps=5):
if self.is_master:
vis_file = torch.load(self.config['visualize_zdm_denoising_file'], map_location='cpu')
vis_images = []
for i in range(len(vis_file['inp'])):
x = (
vis_file['inp'][i]
.to(self.device)
.unsqueeze(0)
.expand(n_selected_timesteps, -1, -1, -1)
)
z = self.model.encode(x)
z = self.model.normalize_for_zdm(z)
t = torch.linspace(0, 1, n_selected_timesteps + 1, device=self.device)[1:]
noise = (
vis_file['noise'][i]
.to(self.device)
.unsqueeze(0)
.expand(n_selected_timesteps, -1, -1, -1)
)
z_t, _ = self.model.zdm_diffusion.add_noise(z, t, noise=noise)
# Visualize noisy latents
zp = self.model.denormalize_for_zdm(z_t)
z_dec = self.model.decode(zp)
coord, scale = make_coord_scale_grid(x.shape[-2:], device=self.device, batch_size=n_selected_timesteps)
coord = coord.permute(0, 3, 1, 2)
scale = scale.permute(0, 3, 1, 2)
x_out = self.model.render(z_dec, coord, scale)
vis_images.append(x_out)
# Generate denoised latents
net = self.model.zdm_net_ema if ema else self.model.zdm_net
net_kwargs = dict()
if self.config.get('visualize_zdm_setting') is not None:
setting = self.config['visualize_zdm_setting']
if setting['name'] == 'class':
net_kwargs['class_labels'] = (
vis_file['class_labels'][i]
.to(self.device)
.unsqueeze(0)
.expand(n_selected_timesteps)
)
else:
raise NotImplementedError
pred = self.model.zdm_diffusion.get_prediction(net, z_t, t, net_kwargs=net_kwargs)
zp = []
for j in range(len(pred)):
zp.append(self.model.zdm_diffusion.convert_sample_prediction(z_t[j], float(t[j]), pred[j]))
zp = torch.stack(zp, dim=0)
# Visualize denoised latents
zp = self.model.denormalize_for_zdm(zp)
z_dec = self.model.decode(zp)
coord, scale = make_coord_scale_grid(x.shape[-2:], device=self.device, batch_size=n_selected_timesteps)
coord = coord.permute(0, 3, 1, 2)
scale = scale.permute(0, 3, 1, 2)
x_out = self.model.render(z_dec, coord, scale)
vis_images.append(x_out)
vis_images = torch.cat(vis_images, dim=0)
vis_images = torchvision.utils.make_grid(vis_images, normalize=True, value_range=(-1, 1), nrow=n_selected_timesteps)
self.log_image('vis_zdm' + ('_ema' if ema else '') + '_denoising', vis_images)
dist.barrier()