gigant commited on
Commit
cb7eca3
·
1 Parent(s): c0f0a4a

vmapped the mmd loss function

Browse files
Files changed (1) hide show
  1. train.py +1 -1
train.py CHANGED
@@ -563,7 +563,7 @@ def main():
563
 
564
  def regulariser_loss(latent_codes, rng):
565
  true_samples = jax.random.normal(rng, latent_codes.shape)
566
- return compute_mmd(true_samples, latent_codes)
567
 
568
  def loss_fn(logits, labels, latent_codes, regulariser_rng):
569
  shift_logits = logits[..., :-1, :]
 
563
 
564
  def regulariser_loss(latent_codes, rng):
565
  true_samples = jax.random.normal(rng, latent_codes.shape)
566
+ return jax.vmap(compute_mmd)(true_samples, latent_codes)
567
 
568
  def loss_fn(logits, labels, latent_codes, regulariser_rng):
569
  shift_logits = logits[..., :-1, :]