Spaces:
Paused
Paused
resolve name conflict
Browse files- diffusion_sampler.py +3 -3
diffusion_sampler.py
CHANGED
|
@@ -98,10 +98,10 @@ class DiffusionSampler(keras.Model):
|
|
| 98 |
return sqrt_alpha_cum_prod_prev * x0_t + x_t_dir * pred_noise + c1 * random_noise
|
| 99 |
|
| 100 |
def noise_scheduler(self, scheduler: str, max_beta: int = 0.02) -> tf.Tensor:
|
| 101 |
-
pi,
|
| 102 |
|
| 103 |
-
alpha_bar = lambda
|
| 104 |
-
cosine_scheduler = lambda
|
| 105 |
|
| 106 |
if scheduler == "linear":
|
| 107 |
x = tf.linspace(start=self.beta_start, stop=self.beta_end, num=self.timesteps)
|
|
|
|
| 98 |
return sqrt_alpha_cum_prod_prev * x0_t + x_t_dir * pred_noise + c1 * random_noise
|
| 99 |
|
| 100 |
def noise_scheduler(self, scheduler: str, max_beta: int = 0.02) -> tf.Tensor:
|
| 101 |
+
pi, T = [tf.constant(num, dtype=tf.float64) for num in (math.pi, self.timesteps)]
|
| 102 |
|
| 103 |
+
alpha_bar = lambda i: tf.math.cos((i + 0.008) / 1.008 * pi / 2) ** 2
|
| 104 |
+
cosine_scheduler = lambda t: tf.minimum(1 - alpha_bar((t + 1) / T) / alpha_bar(t / T), max_beta)
|
| 105 |
|
| 106 |
if scheduler == "linear":
|
| 107 |
x = tf.linspace(start=self.beta_start, stop=self.beta_end, num=self.timesteps)
|