File size: 3,961 Bytes
e321b92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Loss functions used in the paper
"Elucidating the Design Space of Diffusion-Based Generative Models"."""

import torch
from edm.torch_utils import persistence
import pdb
#----------------------------------------------------------------------------
# Loss function corresponding to the variance preserving (VP) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".

@persistence.persistent_class
class VPLoss:
    """Variance-preserving (VP) denoising loss.

    A time value ``t`` is drawn uniformly between ``epsilon_t`` and 1, turned
    into a noise level by :meth:`sigma`, and each sample's squared error is
    scaled by ``1 / sigma**2``.
    """

    def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
        # Schedule coefficients used by ``sigma`` and the lower end of the
        # sampled time range.
        self.beta_d = beta_d
        self.beta_min = beta_min
        self.epsilon_t = epsilon_t

    def noise_and_weight(self, shape, device, sds=False):
        """Draw one noise level per sample together with its loss weight.

        Args:
            shape: Number of samples (batch size) to draw.
            device: Device on which the returned tensors are created.
            sds: When true, the uniform draw is squeezed from [0, 1) into
                [0.02, 0.98) before being mapped to a time value.

        Returns:
            ``(sigma, weight)``, both shaped ``[shape, 1, 1, 1]``, with
            ``weight == 1 / sigma**2``.
        """
        u = torch.rand([shape, 1, 1, 1], device=device)
        if sds:
            # Between 0.02 and 0.98, see https://github.com/ashawkey/stable-dreamfusion/blob/5550b91862a3af7842bb04875b7f1211e5095a63/guidance/sd_utils.py#L180
            u = 0.02 + u * 0.96
        # u = 0 maps to t = 1 and u -> 1 maps to t -> epsilon_t.
        t = 1 + u * (self.epsilon_t - 1)
        noise_level = self.sigma(t)
        return noise_level, 1 / noise_level ** 2

    def __call__(self, net, x, latents, augment_pipe=None):
        """Return the element-wise weighted denoising loss for batch ``x``.

        ``net`` is invoked as ``net(noisy_x, sigma, latents)`` and its output
        is compared with the clean ``x``. ``augment_pipe`` is accepted but not
        used.
        """
        sigma, weight = self.noise_and_weight(x.shape[0], x.device)
        noisy = x + torch.randn_like(x) * sigma
        denoised = net(noisy, sigma, latents)
        return weight * (denoised - x) ** 2

    def sigma(self, t):
        """Map time ``t`` to ``sqrt(exp(0.5*beta_d*t**2 + beta_min*t) - 1)``."""
        t = torch.as_tensor(t)
        exponent = 0.5 * self.beta_d * (t ** 2) + self.beta_min * t
        return (exponent.exp() - 1).sqrt()

#----------------------------------------------------------------------------
# Loss function corresponding to the variance exploding (VE) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".

@persistence.persistent_class
class VELoss:
    """Variance-exploding (VE) denoising loss.

    Noise levels are drawn log-uniformly from ``[sigma_min, sigma_max)`` and
    each sample's squared error is weighted by ``1 / sigma**2``.
    """

    def __init__(self, sigma_min=0.02, sigma_max=100):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def noise_and_weight(self, shape, device, sds=False):
        """Sample one noise level per sample and its loss weight.

        Args:
            shape: Number of samples (batch size) to draw noise levels for.
            device: Device on which the returned tensors are allocated.
            sds: Accepted for interface parity with ``VPLoss``; ignored here
                (as in ``EDMLoss``).

        Returns:
            ``(sigma, weight)``, each of shape ``[shape, 1, 1, 1]``, with
            ``weight == 1 / sigma**2``.
        """
        # Use the arguments (not an out-of-scope ``x``) and keep trailing
        # singleton dims so sigma broadcasts per-sample over the batch
        # (assumed (N, C, H, W)-shaped, like VPLoss/EDMLoss).
        rnd_uniform = torch.rand([shape, 1, 1, 1], device=device)
        # sigma_min * (sigma_max / sigma_min) ** u  ==> log-uniform in [sigma_min, sigma_max)
        sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform)
        weight = 1 / sigma ** 2
        return sigma, weight

    def __call__(self, net, x, latents, augment_pipe=None):
        """Return the element-wise weighted denoising loss for batch ``x``.

        ``net`` is called as ``net(noisy_x, sigma, latents)`` and its output is
        compared with the clean ``x``. ``augment_pipe`` is accepted but unused.
        """
        sigma, weight = self.noise_and_weight(x.shape[0], x.device)
        n = torch.randn_like(x) * sigma
        D_xn = net(x + n, sigma, latents)
        loss = weight * ((D_xn - x) ** 2)
        return loss

#----------------------------------------------------------------------------
# Improved loss function proposed in the paper "Elucidating the Design Space
# of Diffusion-Based Generative Models" (EDM).

@persistence.persistent_class
class EDMLoss:
    """EDM denoising loss.

    ``ln(sigma)`` is drawn from a normal distribution with mean ``P_mean`` and
    standard deviation ``P_std``. Each sample's squared error is weighted by
    ``(sigma**2 + sigma_data**2) / (sigma * sigma_data)**2``.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        # Fixed settings that the methods of this class do not read;
        # presumably consumed by code elsewhere — verify before changing.
        self.sigma_min = 0.4
        self.sigma_max = 10
        self.rho = 3

    def noise_and_weight(self, shape, device, sds=False):
        """Draw one log-normal noise level per sample and its loss weight.

        Args:
            shape: Number of samples (batch size) to draw.
            device: Device on which the returned tensors are created.
            sds: Accepted for interface compatibility; not used here.

        Returns:
            ``(sigma, weight)``, both shaped ``[shape, 1, 1, 1]`` and cast to
            float32.
        """
        log_sigma = torch.randn([shape, 1, 1, 1], device=device) * self.P_std + self.P_mean
        sigma = log_sigma.exp()
        data_var = self.sigma_data ** 2
        weight = (sigma ** 2 + data_var) / (sigma * self.sigma_data) ** 2
        return sigma.float(), weight.float()

    def __call__(self, net, x, latents, augment_pipe=None):
        """Return the element-wise weighted denoising loss for batch ``x``.

        ``net`` is invoked as ``net(noisy_x, sigma, latents)`` and its output
        is compared with the clean ``x``. ``augment_pipe`` is accepted but not
        used.
        """
        sigma, weight = self.noise_and_weight(x.shape[0], x.device)
        noisy = x + torch.randn_like(x) * sigma
        denoised = net(noisy, sigma, latents)
        return weight * (denoised - x) ** 2

#----------------------------------------------------------------------------