|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Loss functions from the paper "Elucidating the Design Space of Diffusion-Based
Generative Models" (EDM), extended with patch masking and an auxiliary MAE loss."""
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from utils import * |
|
|
from train_utils.helper import unwrap_model |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EDMLoss:
    """EDM denoising loss with optional patch masking and an auxiliary MAE term.

    Per-sample noise levels are drawn log-normally,
    ``ln(sigma) ~ N(P_mean, P_std**2)``, and the squared denoising error is
    weighted by ``(sigma**2 + sigma_data**2) / (sigma * sigma_data)**2``.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
        # Mean and std of the normal distribution ln(sigma) is sampled from.
        self.P_mean = P_mean
        self.P_std = P_std
        # sigma_data term of the loss weighting (EDM's data standard deviation).
        self.sigma_data = sigma_data

    def __call__(self, net,
                 images,
                 labels=None,
                 mask_ratio=0,
                 mae_loss_coef=0,
                 feat=None, augment_pipe=None):
        """Compute the per-sample training loss.

        Args:
            net: Denoiser, optionally wrapped (e.g. in DDP). It is called as
                ``net(x_noisy, sigma, labels, mask_ratio=..., mask_dict=None,
                feat=...)`` and must return a dict whose ``'x'`` entry is the
                denoised prediction (same shape as ``images``). When
                ``mask_ratio > 0`` the dict must also contain ``'mask'``, a
                per-patch mask where 1 marks a masked patch.
            images: Clean 4-D image batch ``(N, C, H, W)``.
            labels: Conditioning labels forwarded to ``net``.
            mask_ratio: Masking ratio forwarded to ``net``; values > 0 enable
                masked training, where the denoising loss covers only visible
                patches.
            mae_loss_coef: Weight of the auxiliary MAE loss on masked patches.
                Used only when ``mask_ratio > 0``.
            feat: Extra features forwarded to ``net``.
            augment_pipe: Optional callable returning
                ``(augmented_images, augment_labels)``.

        Returns:
            Tensor of shape ``(N,)`` with one loss value per sample.
        """
        # Use the unwrapped model for attribute access so both wrapped (DDP)
        # and plain models work; ``net.module`` exists only on wrappers.
        raw_net = unwrap_model(net)

        # Sample per-sample noise levels: ln(sigma) ~ N(P_mean, P_std^2).
        rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        y, _ = augment_pipe(images) if augment_pipe is not None else (images, None)
        n = torch.randn_like(y) * sigma

        model_out = net(y + n, sigma, labels, mask_ratio=mask_ratio, mask_dict=None, feat=feat)
        D_yn = model_out['x']
        assert D_yn.shape == y.shape
        loss = weight * ((D_yn - y) ** 2)

        if mask_ratio > 0:
            assert net.training and 'mask' in model_out
            # Average over channels, then over each patch_size x patch_size
            # window, giving a per-patch loss of shape (N, L).
            loss = F.avg_pool2d(loss.mean(dim=1), raw_net.model.patch_size).flatten(1)
            unmask = 1 - model_out['mask']
            # Denoising loss is averaged over visible (unmasked) patches only.
            loss = (loss * unmask).sum(dim=1) / unmask.sum(dim=1)
            assert loss.ndim == 1
            if mae_loss_coef > 0:
                # Auxiliary reconstruction loss on the masked patches.
                loss = loss + mae_loss_coef * mae_loss(raw_net, y + n, D_yn, 1 - unmask)
        else:
            loss = mean_flat(loss)

        # With masking off, mask_token receives no gradient. This zero-valued
        # term keeps it in the autograd graph (presumably so DDP does not
        # report unused parameters).
        if mask_ratio == 0.0 and raw_net.model.mask_token is not None:
            loss = loss + 0 * torch.sum(raw_net.model.mask_token)
        assert loss.ndim == 1
        return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Registry of the available loss classes, keyed by short name.
Losses = dict(edm=EDMLoss)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def patchify(imgs, patch_size=2, num_channels=4):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        imgs: Tensor of shape ``(N, C, H, W)`` with ``C == num_channels`` and
            both ``H`` and ``W`` divisible by ``patch_size``. ``H`` and ``W``
            need not be equal.
        patch_size: Side length ``p`` of each square patch.
        num_channels: Number of channels ``C`` in ``imgs``.

    Returns:
        Tensor of shape ``(N, L, p**2 * C)`` with ``L = (H // p) * (W // p)``.
        Patches are listed in row-major order over the patch grid, and each
        patch's features are ordered ``(row-in-patch, col-in-patch, channel)``.
    """
    p, c = patch_size, num_channels
    batch, img_c, img_h, img_w = imgs.shape
    assert img_c == c, f"expected {c} channels, got {img_c}"
    assert img_h % p == 0 and img_w % p == 0, (
        f"image size {img_h}x{img_w} is not divisible by patch size {p}")

    h, w = img_h // p, img_w // p
    x = imgs.reshape(batch, c, h, p, w, p)
    # Move the patch-grid axes to the front and the in-patch/channel axes last.
    x = torch.einsum('nchpwq->nhwpqc', x)
    return x.reshape(batch, h * w, p ** 2 * c)
|
|
|
|
|
|
|
|
def mae_loss(net, target, pred, mask, norm_pix_loss=True):
    """Masked-autoencoder reconstruction loss, averaged over weighted patches.

    ``target`` and ``pred`` are both split into patches with ``patchify``,
    using ``net.model.patch_size`` and ``net.model.out_channels``.

    Args:
        net: Unwrapped model exposing ``model.patch_size`` and
            ``model.out_channels``.
        target: Reconstruction target images of shape ``(N, C, H, W)``.
        pred: Predicted images with the same shape as ``target``.
        mask: Per-patch weights, expected ``(N, L)``. The per-patch errors are
            averaged with these weights (1 = include the patch).
        norm_pix_loss: If True, standardize each target patch over its
            ``p**2 * C`` features (zero mean, unit variance, eps 1e-6) before
            the comparison. The prediction is not normalized.

    Returns:
        Tensor of shape ``(N,)`` with one loss value per sample.
    """
    cfg = net.model
    tgt_patches = patchify(target, cfg.patch_size, cfg.out_channels)
    pred_patches = patchify(pred, cfg.patch_size, cfg.out_channels)

    if norm_pix_loss:
        mu = tgt_patches.mean(dim=-1, keepdim=True)
        std = (tgt_patches.var(dim=-1, keepdim=True) + 1.e-6) ** .5
        tgt_patches = (tgt_patches - mu) / std

    # Mean squared error within each patch -> (N, L).
    per_patch = ((pred_patches - tgt_patches) ** 2).mean(dim=-1)
    weighted_sum = (per_patch * mask).sum(dim=1)
    result = weighted_sum / mask.sum(dim=1)
    assert result.ndim == 1
    return result
|
|
|