# MIT License
# Copyright (c) [2023] [Anima-Lab]
# This code is adapted from https://github.com/NVlabs/edm/blob/main/training/loss.py.
# The original code is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License, which can be found at licenses/LICENSE_EDM.txt.

"""Loss functions used in the paper "Elucidating the Design Space of Diffusion-Based Generative Models"."""

import torch
import torch.nn.functional as F

from utils import *
from train_utils.helper import unwrap_model


# Improved loss function proposed in the paper "Elucidating the Design Space
# of Diffusion-Based Generative Models" (EDM).
class EDMLoss:
    """EDM denoising loss with optional token masking and an auxiliary MAE term.

    Noise levels are sampled log-normally, ``ln(sigma) ~ N(P_mean, P_std**2)``,
    and each squared denoising error is weighted by
    ``(sigma**2 + sigma_data**2) / (sigma * sigma_data)**2``.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
        self.P_mean = P_mean          # mean of ln(sigma)
        self.P_std = P_std            # standard deviation of ln(sigma)
        self.sigma_data = sigma_data  # data scale used in the loss weighting

    def __call__(self, net, images, labels=None, mask_ratio=0, mae_loss_coef=0, feat=None, augment_pipe=None):
        """Compute the per-sample EDM loss for a batch.

        Args:
            net: denoiser, possibly wrapped (it is unwrapped with
                ``unwrap_model`` to reach ``.model``). Called as
                ``net(x_noisy, sigma, labels, mask_ratio=..., mask_dict=None, feat=...)``
                and must return a dict whose ``'x'`` entry has the same shape
                as the (augmented) images and, when ``mask_ratio > 0``, a
                ``'mask'`` entry broadcastable to the (N, L) per-token loss.
            images: clean images, (N, C, H, W).
            labels: conditioning forwarded to ``net`` unchanged.
            mask_ratio: forwarded to ``net``; when > 0 the denoising loss is
                averaged only over tokens where ``mask == 0``.
            mae_loss_coef: weight of the auxiliary ``mae_loss`` term on masked
                tokens; only used when ``mask_ratio > 0``.
            feat: extra input forwarded to ``net`` unchanged.
            augment_pipe: optional callable returning
                ``(augmented_images, augment_labels)``.

        Returns:
            Tensor of shape (N,), one loss value per sample.
        """
        # One noise level per sample; shape (N, 1, 1, 1) broadcasts over (C, H, W).
        rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        # NOTE(review): augment_labels is computed but never passed to net
        # (upstream EDM forwards it) — confirm this is intended.
        y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
        n = torch.randn_like(y) * sigma
        model_out = net(y + n, sigma, labels, mask_ratio=mask_ratio, mask_dict=None, feat=feat)
        D_yn = model_out['x']
        assert D_yn.shape == y.shape
        loss = weight * ((D_yn - y) ** 2)  # (N, C, H, W)

        # Unwrap once and reuse it everywhere; hard-coding `net.module` would
        # fail whenever `net` is not wrapped (e.g. single-process training).
        raw_net = unwrap_model(net)

        if mask_ratio > 0:
            assert net.training and 'mask' in model_out
            # Average over channels, then within each patch: one value per token.
            loss = F.avg_pool2d(loss.mean(dim=1), raw_net.model.patch_size).flatten(1)  # (N, L)
            unmask = 1 - model_out['mask']
            # The denoising loss only counts tokens the model could see.
            loss = (loss * unmask).sum(dim=1) / unmask.sum(dim=1)  # (N)
            assert loss.ndim == 1
            if mae_loss_coef > 0:
                loss += mae_loss_coef * mae_loss(raw_net, y + n, D_yn, 1 - unmask)
        else:
            # NOTE(review): mean_flat comes from `utils` (star import); presumably
            # averages over all non-batch dims — confirm.
            loss = mean_flat(loss)  # (N)

        if mask_ratio == 0.0 and raw_net.model.mask_token is not None:
            # Zero-valued term that still connects mask_token to the autograd
            # graph; presumably so distributed training does not flag it as an
            # unused parameter when masking is off — confirm.
            loss += 0 * torch.sum(raw_net.model.mask_token)
        assert loss.ndim == 1
        return loss
# ----------------------------------------------------------------------------

Losses = {
    'edm': EDMLoss
}

# ----------------------------------------------------------------------------


def patchify(imgs, patch_size=2, num_channels=4):
    """Split a batch of images into flattened, non-overlapping square patches.

    imgs: (N, C, H, W), where C must equal ``num_channels`` and both H and W
        must be divisible by ``patch_size``. H and W need not be equal.
    x: (N, L, patch_size**2 * C), with L = (H // patch_size) * (W // patch_size).
        Patches are ordered row-major over the patch grid, and each patch
        vector is laid out as (row-in-patch, col-in-patch, channel).
    """
    p, c = patch_size, num_channels
    batch, _, height, width = imgs.shape
    assert height % p == 0 and width % p == 0, (
        f"image size ({height}, {width}) is not divisible by patch_size={p}"
    )
    h, w = height // p, width // p
    x = imgs.reshape(batch, c, h, p, w, p)
    # (n, c, h, p, w, q) -> (n, h, w, p, q, c): gather each patch's pixels together.
    x = torch.einsum('nchpwq->nhwpqc', x)
    return x.reshape(batch, h * w, p ** 2 * c)


def mae_loss(net, target, pred, mask, norm_pix_loss=True):
    """Per-sample squared reconstruction error averaged over the removed patches.

    Args:
        net: unwrapped model; only ``net.model.patch_size`` and
            ``net.model.out_channels`` are read, to patchify the inputs.
        target: images to reconstruct, (N, C, H, W) with C == net.model.out_channels.
        pred: prediction with the same shape as ``target``.
        mask: per-patch weights, (N, L) after broadcasting; 1 marks a removed
            patch that contributes to the loss, 0 a visible patch that does not.
        norm_pix_loss: if True, normalize each target patch to zero mean and
            unit variance before comparing (``pred`` is not normalized).

    Returns:
        Tensor of shape (N,). A sample whose mask sums to zero yields NaN (0/0).
    """
    p, c = net.model.patch_size, net.model.out_channels
    target = patchify(target, p, c)
    pred = patchify(pred, p, c)
    if norm_pix_loss:
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1.e-6) ** .5
    loss = ((pred - target) ** 2).mean(dim=-1)  # [N, L], mean loss per patch
    loss = (loss * mask).sum(dim=1) / mask.sum(dim=1)  # mean loss on removed patches, (N)
    assert loss.ndim == 1
    return loss