vmapped the mmd loss function
Browse files
train.py
CHANGED
|
@@ -563,7 +563,7 @@ def main():
|
|
| 563 |
|
| 564 |
def regulariser_loss(latent_codes, rng):
|
| 565 |
true_samples = jax.random.normal(rng, latent_codes.shape)
|
| 566 |
-
return compute_mmd(true_samples, latent_codes)
|
| 567 |
|
| 568 |
def loss_fn(logits, labels, latent_codes, regulariser_rng):
|
| 569 |
shift_logits = logits[..., :-1, :]
|
|
|
|
| 563 |
|
| 564 |
def regulariser_loss(latent_codes, rng):
|
| 565 |
true_samples = jax.random.normal(rng, latent_codes.shape)
|
| 566 |
+
return jax.vmap(compute_mmd)(true_samples, latent_codes)
|
| 567 |
|
| 568 |
def loss_fn(logits, labels, latent_codes, regulariser_rng):
|
| 569 |
shift_logits = logits[..., :-1, :]
|