# File size: 6,186 Bytes
# afe65cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import numpy as np
import abc
import functools

class SDE(abc.ABC):
  """SDE abstract class. Functions are designed for a mini-batch of inputs."""

  def __init__(self, N):
    """Construct an SDE.

    Args:
      N: number of discretization time steps.
    """
    super().__init__()
    self.N = N

  @property
  @abc.abstractmethod
  def T(self):
    """End time of the SDE."""
    pass

  @abc.abstractmethod
  def sde(self, x, t):
    """Return the `(drift, diffusion)` pair of the SDE at state `x` and time `t`."""
    pass

  @abc.abstractmethod
  def marginal_prob(self, x, t):
    """Parameters to determine the marginal distribution of the SDE, $p_t(x)$."""
    pass

  # NOTE(review): this signature takes `rng`, but code that calls
  # `prior_sampling(shape)` with a single argument needs an override without
  # it -- confirm which signature is intended.
  @abc.abstractmethod
  def prior_sampling(self, rng, shape):
    """Generate one sample from the prior distribution, $p_T(x)$."""
    pass

  def reverse(self, score_fn, probability_flow=False):
    """Create the reverse-time SDE/ODE.

    Args:
      score_fn: A time-dependent score-based model. It is called with keyword
        arguments as `score_fn(y=x, t=t)` and returns the score.
      probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.

    Returns:
      An instance of a dynamically created subclass of `type(self)` whose
      `sde(x, t)` returns the drift and diffusion of the reverse-time SDE/ODE.
    """
    # Capture what the reverse object needs in the closure; the RSDE instance
    # below is built without running the parent constructor.
    N = self.N
    T = self.T
    sde_fn = self.sde

    # Build the class for reverse-time SDE.
    class RSDE(self.__class__):
      def __init__(self):
        # Deliberately does not call the parent __init__: its arguments are
        # unknown here, and only N / probability_flow are needed.
        self.N = N
        self.probability_flow = probability_flow

      @property
      def T(self):
        return T

      def sde(self, x, t):
        """Create the drift and diffusion functions for the reverse SDE/ODE."""
        drift, diffusion = sde_fn(x, t)
        # Evaluate the score exactly once per call; it is typically the most
        # expensive part of a reverse step.
        score = score_fn(y=x, t=t)
        drift = drift - (diffusion ** 2)*(score * (0.5 if self.probability_flow else 1.))
        # Set the diffusion function to zero for ODEs.
        diffusion = np.zeros_like(diffusion) if self.probability_flow else diffusion
        return drift, diffusion

    return RSDE()

class VPSDE(SDE):
  """Variance Preserving SDE whose beta(t) is linear between beta_min and beta_max."""

  def __init__(self, beta_min=0.1, beta_max=20, N=1000):
    """Construct a Variance Preserving SDE.

    Args:
      beta_min: value of beta(0)
      beta_max: value of beta(1)
      N: number of discretization steps
    """
    super().__init__(N)  # the base constructor stores N
    self.beta_0, self.beta_1 = beta_min, beta_max

    # Discrete schedule: N evenly spaced betas, each scaled by 1/N.
    betas = np.linspace(beta_min / N, beta_max / N, N)
    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)

    self.discrete_betas = betas
    self.alphas = alphas
    self.alphas_cumprod = alphas_cumprod
    self.sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
    self.sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)

  @property
  def T(self):
    """End time of the SDE."""
    return 1

  def _log_mean_coeff(self, t):
    """Log of the factor that multiplies `x` in the marginal mean at time `t`."""
    return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0

  def sde(self, x, t):
    """Return the forward `(drift, diffusion)` at state `x` and time `t`."""
    beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
    return -0.5 * beta_t * x, np.sqrt(beta_t)

  def marginal_prob(self, x, t):
    """Return the `(mean, std)` parameters of the marginal at time `t` for input `x`."""
    log_coeff = self._log_mean_coeff(t)
    return np.exp(log_coeff)*x, np.sqrt(1 - np.exp(2. * log_coeff))

  def marginal_prob_coef(self, x, t):
    """Return `(mean coefficient, std)` at time `t`; `x` is accepted but not used."""
    log_coeff = self._log_mean_coeff(t)
    return np.exp(log_coeff), np.sqrt(1 - np.exp(2. * log_coeff))

  def prior_sampling(self, shape):
    """Draw a standard-normal array of the given `shape` from NumPy's global RNG."""
    return np.random.normal(size=shape)

class Predictor(abc.ABC):
  """The abstract class for a predictor algorithm.

  On construction it stores the forward SDE, the score function, and the
  reverse-time SDE obtained from `sde.reverse(score_fn)`.
  """

  def __init__(self, sde, score_fn):
    """Store the inputs and build the reverse-time SDE.

    Args:
      sde: object exposing `reverse(score_fn)`; kept as `self.sde`.
      score_fn: score function handed to `sde.reverse`; kept as `self.score_fn`.
    """
    super().__init__()
    self.sde = sde
    # Compute the reverse SDE/ODE
    self.rsde = sde.reverse(score_fn)
    self.score_fn = score_fn

  @abc.abstractmethod
  def update_fn(self, x, t, h=None):
    """Perform one predictor update.

    Args:
      x: current state.
      t: current time.
      h: step size for this update. Optional in the abstract contract so that
        predictors which do not need an explicit step size can ignore it.
    """
    pass

class EulerMaruyamaPredictor(Predictor):
  """Predictor that takes a single Euler-Maruyama step using the reverse SDE."""

  def __init__(self, sde, score_fn):
    super().__init__(sde, score_fn)

  def update_fn(self, x, t, h):
    """Step `x` from time `t` by step size `h`.

    Returns:
      A pair `(x_next, x_mean)`: the noisy next state and the same update
      without the noise term.
    """
    # Sample the noise before evaluating the reverse SDE, so the order of
    # random draws stays fixed.
    noise = self.sde.prior_sampling(x.shape)
    drift, diffusion = self.rsde.sde(x, t)
    x_mean = x - drift * h
    return x_mean + diffusion*np.sqrt(h)*noise, x_mean

def shared_predictor_update_fn(x, t, h=None, sde=None, score_fn=None):
  """Build an `EulerMaruyamaPredictor` for `sde`/`score_fn` and run one update.

  A new predictor is constructed on every call. Returns whatever the
  predictor's `update_fn(x, t, h)` returns.
  """
  return EulerMaruyamaPredictor(sde, score_fn).update_fn(x, t, h)

# VP sampler

def get_pc_sampler(score_fn, sde, denoise=True, eps=1e-3, repaint=False):
  """Create a sampler that steps the predictor from time `sde.T` down to `eps`.

  Args:
    score_fn: score function, called as `score_fn(y=x, t=t)`.
    sde: forward SDE object; `T`, `N`, `sde`, `marginal_prob` and
      `prior_sampling` are used.
    denoise: if true, finish with `x + std**2 * score_fn(y=x, t=eps)`, where
      `std` comes from `sde.marginal_prob(x, eps)`.
    eps: final (smallest) time of the time grid.
    repaint: if true, return the sampler variant that periodically steps
      forward in time again and redoes the last `j` predictor steps.

  Returns:
    A function `sampler(prior, r=5, j=5)` returning the final sample. `r` and
    `j` only have an effect in the repaint variant.
  """

  predictor_update_fn = functools.partial(shared_predictor_update_fn,
                                          sde=sde,
                                          score_fn=score_fn)

  def _time_grid():
    """Return the decreasing time grid and the matching step sizes."""
    timesteps = np.linspace(sde.T, eps, sde.N)
    # true step-size: difference between current time and next time
    # (the entry after the last grid point is taken to be time 0)
    h = timesteps - np.append(timesteps, 0)[1:]
    return timesteps, h

  def _finalize(x):
    """Apply the optional final denoising step."""
    if denoise: # Tweedie formula
      _, std = sde.marginal_prob(x, eps)
      x = x + (std ** 2)*score_fn(y=x, t=eps)
    return x

  def pc_sampler(prior, r=5, j=5):
    """Run `sde.N - 1` predictor steps starting from `prior`."""
    timesteps, h = _time_grid()
    x = prior
    # The last grid point is where the state ends up, so no step starts there.
    for t_i, h_i in zip(timesteps[:-1], h[:-1]):
      x, _ = predictor_update_fn(x, t_i, h_i)
    return _finalize(x)

  def pc_sampler_repaint(prior, r=5, j=5):
    """Run the predictor, redoing each window of `j` steps `r` times in total.

    Raises:
      ValueError: if `j < 1` (a zero `j` breaks the modulo test and a negative
        one would silently skip predictor steps).
    """
    if j < 1:
      raise ValueError("j must be >= 1, got {}".format(j))

    timesteps, h = _time_grid()
    n_steps = sde.N - 1
    x = prior

    n_repaint = 0
    i = 0
    while i < n_steps:
      x, _ = predictor_update_fn(x, timesteps[i], h[i])
      if (i + 1) % j == 0:
        if n_repaint < r - 1:
          # we did j iterations, but not enough repaint, we must repaint again
          # Going backward in time; using Euler-Maruyama
          z = sde.prior_sampling(x.shape)
          drift, diffusion = sde.sde(x, timesteps[i])
          # total time covered by the j steps just taken
          h_ = sum(h[(i-j+1):(i+1)])
          x_mean = x + drift * h_
          x = x_mean + diffusion*np.sqrt(h_)*z
          # iterate back: the increment below lands on the first step of the window
          n_repaint += 1
          i -= j
        elif n_repaint == r - 1:
          # we did j iterations and enough repaint, we continue and reset the repaint counter
          n_repaint = 0
      i += 1

    return _finalize(x)

  return pc_sampler_repaint if repaint else pc_sampler