leowajda commited on
Commit
fc74cfd
·
1 Parent(s): 0948b11

reduce tracing

Browse files
Files changed (1) hide show
  1. diffusion_sampler.py +1 -1
diffusion_sampler.py CHANGED
@@ -118,7 +118,7 @@ class DiffusionSampler(keras.Model):
118
  sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
119
  return sqrt_alpha_cum_prod * x_start + sqrt_one_minus_alpha_cum_prod * noise
120
 
121
- @tf.function()
122
  def generate_images(
123
  self,
124
  num_images: int,
 
118
  sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
119
  return sqrt_alpha_cum_prod * x_start + sqrt_one_minus_alpha_cum_prod * noise
120
 
121
+ @tf.function(reduce_tracing=True)
122
  def generate_images(
123
  self,
124
  num_images: int,