File size: 5,163 Bytes
d382778
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
from samplers.general_solver import ODESolver

class DPM_Solver(ODESolver):
    """Single-step DPM-Solver (orders 1 and 2) for noise-prediction models.

    Order 1 is DDIM (DPM-Solver-1); order 2 is DPM-Solver-2. The first model
    evaluation of each step is made at ``s2`` (and the intermediate one at a
    time derived from ``r2``), which may differ from the times ``s1`` / ``r1``
    the solver's update formulas use.

    ``noise_schedule`` must provide ``marginal_lambda``,
    ``marginal_log_mean_coeff``, ``marginal_std``, ``inverse_lambda``,
    ``total_N`` and ``T``.
    """

    def __init__(
        self,
        noise_schedule,
        algorithm_type="noise_prediction", # need to be noise prediction!
    ):
        super().__init__(noise_schedule, algorithm_type)
        self.noise_schedule = noise_schedule


    def compute_K_and_order(self, steps, order):
        """Split a budget of ``steps`` model evaluations (NFEs) into solver steps.

        Args:
            steps: total number of model evaluations.
            order: 1 or 2.

        Returns:
            ``(K, orders)``: the number of solver steps and the order of each.
            A second-order step costs two evaluations, so with ``order == 2``
            and an odd budget the final step falls back to order 1.
        """
        assert order in [1, 2]
        if order == 1:
            K = steps
            orders = [1,] * steps
        elif order == 2:
            if steps % 2 == 0:
                K = steps // 2
                orders = [2,] * K
            else:
                K = steps // 2 + 1
                orders = [2,] * (K - 1) + [1]
        return K, orders


    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, flags, device):
        '''
        Build the timestep schedule and per-step orders for a run of ``steps`` NFEs.

        order == 1 means DDIM (DPM-Solver-1); order == 2 means DPM-Solver-2.

        Returns ``(timesteps, timesteps2, rs, rs2, orders)``. The first four
        come from ``self.prepare_timesteps_single``, which is defined outside
        this class. NOTE(review): their exact shapes are set there, not here.
        '''
        K, orders = self.compute_K_and_order(steps, order)
        timesteps_outer, timesteps_outer2, rs, rs2 = self.prepare_timesteps_single(steps=K, NFEs=steps, t_start=t_T, t_end=t_0, flags=flags, device=device, skip_type=skip_type)
        return timesteps_outer, timesteps_outer2, rs, rs2, orders


    def dpm_solver_first_update(self, x, s1, s2, t1, model_s=None):
        """First-order (DDIM) step from time ``s1`` to ``t1``.

        Computes ``exp(log_alpha_t - log_alpha_s) * x - sigma_t * expm1(h) * model_s``
        with ``h = lambda_t - lambda_s``.

        Args:
            x: sample at ``s1``.
            s1: start time used by the update formula.
            s2: time at which the model is evaluated (only when ``model_s`` is None).
            t1: end time.
            model_s: optional precomputed noise prediction; skips the model call.
        """
        ns = self.noise_schedule
        lambda_s, lambda_t = ns.marginal_lambda(s1), ns.marginal_lambda(t1)
        h = lambda_t - lambda_s
        log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t1)
        sigma_t = ns.marginal_std(t1)

        phi_1 = torch.expm1(h)
        if model_s is None:
            model_s = self.model_fn(x, s2) # noise prediction!
        x_t = (
            torch.exp(log_alpha_t - log_alpha_s) * x
            - (sigma_t * phi_1) * model_s
        )

        return x_t


    def dpm_solver_second_update(self, x, s1, s2, t1, r1=0.5, r2=0.5, model_s=None):
        """Second-order (DPM-Solver-2) step from ``s1`` to ``t1``.

        An intermediate point is placed at ``lambda_s + r1 * h``. The model is
        evaluated there at the time corresponding to ``lambda_s + r2 * h``,
        which lets the evaluation time differ from the solver time. The result
        is combined with ``model_s`` using the ``0.5 / r1`` difference
        correction. Uses two model calls, or one if ``model_s`` is given.

        Args:
            x: sample at ``s1``.
            s1: start time used by the update formula.
            s2: time at which the first model evaluation is made.
            t1: end time.
            r1: solver ratio of the intermediate point within the step.
            r2: ratio used to pick the intermediate model-evaluation time.
            model_s: optional precomputed noise prediction at the start.
        """
        ns = self.noise_schedule
        lambda_s, lambda_t = ns.marginal_lambda(s1), ns.marginal_lambda(t1)
        h = lambda_t - lambda_s
        lambda_s_inter1 = lambda_s + r1 * h
        lambda_s_inter2 = lambda_s + r2 * h
        s_inter1 = ns.inverse_lambda(lambda_s_inter1)
        s_inter2 = ns.inverse_lambda(lambda_s_inter2)

        log_alpha_s, log_alpha_s_inter, log_alpha_t = ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s_inter1), ns.marginal_log_mean_coeff(t1)
        sigma_s_inter, sigma_t = ns.marginal_std(s_inter1), ns.marginal_std(t1)

        phi_1_inter = torch.expm1(r1 * h)
        phi_1 = torch.expm1(h)

        if model_s is None:
            model_s = self.model_fn(x, s2)

        # First-order move to the intermediate point (solver time s_inter1).
        x_s_inter = (
            torch.exp(log_alpha_s_inter - log_alpha_s) * x
            - (sigma_s_inter * phi_1_inter) * model_s
        )

        # The model is queried at the decoupled time s_inter2, not s_inter1.
        model_s_inter = self.model_fn(x_s_inter, s_inter2)
        x_t = (
            torch.exp(log_alpha_t - log_alpha_s) * x
            - (sigma_t * phi_1) * model_s
            - (0.5 / r1) * (sigma_t * phi_1) * (model_s_inter - model_s)
        )

        return x_t


    def singlestep_dpm_solver_update(self, x, s1, s2, t1, order, r1=0.5, r2=0.5, model_s=None):
        """Dispatch one step to the first- or second-order update.

        Raises:
            ValueError: if ``order`` is not 1 or 2.
        """
        if order == 1:
            x_t = self.dpm_solver_first_update(x, s1, s2, t1, model_s=model_s)
        elif order == 2:
            x_t = self.dpm_solver_second_update(x, s1, s2, t1, r1, r2, model_s=model_s)
        else:
            raise ValueError("Order must be 1 or 2.")
        return x_t


    def sample(
        self,
        model_fn,
        x,
        steps=20,
        t_start=None,
        t_end=None,
        order=2,
        skip_type="time_uniform",
        flags=None,
    ):
        """Sample from ``x`` with ``steps`` model evaluations, without gradients.

        Args:
            model_fn: noise-prediction model, called as ``model_fn(x, t)`` with
                ``t`` expanded to ``x.shape[0]``.
            x: starting sample; its ``device`` is used for the timesteps.
            steps: total number of model evaluations (NFEs).
            t_start: start time; defaults to ``noise_schedule.T``.
            t_end: end time; defaults to ``1 / noise_schedule.total_N``.
            order: must be 2 for this entry point.
            skip_type: forwarded to the timestep schedule builder.
            flags: forwarded to the timestep schedule builder.

        Returns:
            The final sample.
        """
        # Only the second-order solver is supported through this entry point.
        assert order == 2
        t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        device = x.device

        timesteps, timesteps2, rs, rs2, orders = self.get_orders_and_timesteps_for_singlestep_solver(
            steps, order, skip_type, t_T, t_0, flags, device
        )

        # sample_simple forwards (condition, unconditional_condition) to the
        # model. This entry point has no conditioning, so drop those and keep
        # the two-argument model_fn(x, t) contract.
        def unconditional_model_fn(x_, t_, *_unused):
            return model_fn(x_, t_)

        with torch.no_grad():
            # Keywords guard against the positional mix-up that previously
            # passed the order list as `timesteps`.
            return self.sample_simple(
                unconditional_model_fn,
                x,
                timesteps,
                timesteps2,
                order=orders,
                rs=rs,
                rs2=rs2,
            )


    def sample_simple(self, model_fn, x, timesteps, timesteps2, order=2, rs=None, rs2=None, condition=None, unconditional_condition=None, **kwargs):
        '''
        Run the single-step solver over a prepared timestep schedule.

        Args:
            model_fn: called as ``model_fn(x, t, condition, unconditional_condition)``
                with ``t`` expanded to ``x.shape[0]``.
            x: starting sample.
            timesteps: solver times; step ``i`` goes from ``timesteps[i]`` to
                ``timesteps[i + 1]``.
            timesteps2: per-step times at which the first model evaluation is made.
            order: a list with the order (1 or 2) of each step, or a single int
                used for all ``len(timesteps) - 1`` steps.
            rs, rs2: per-step intermediate ratios; default 0.5 for every step.
            condition, unconditional_condition: forwarded to ``model_fn``.
            **kwargs: accepted and ignored.

        Returns:
            The sample after the last step.
        '''
        # NOTE(review): presumably read by the parent's model_fn; confirm in ODESolver.
        self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])), condition, unconditional_condition)

        if rs is None:
            rs = [0.5,] * len(timesteps)
        if rs2 is None:
            rs2 = [0.5,] * len(timesteps)

        # A scalar order (including the default) applies to every step.
        # Iterating the int directly would raise TypeError.
        if isinstance(order, int):
            orders = [order] * (len(timesteps) - 1)
        else:
            orders = order

        for step, od in enumerate(orders):
            s1, t1 = timesteps[step], timesteps[step + 1]
            s2 = timesteps2[step]
            r1, r2 = rs[step], rs2[step]
            x = self.singlestep_dpm_solver_update(x, s1, s2, t1, od, r1, r2)
        return x