added rng key to the eval_step
Browse files
train.py
CHANGED
|
@@ -597,7 +597,7 @@ def main():
|
|
| 597 |
def eval_step(params, batch):
|
| 598 |
labels = batch.pop("labels")
|
| 599 |
logits, latent_codes = model(**batch, params=params, train=False)[:2]
|
| 600 |
-
loss = loss_fn(logits, labels, latent_codes)
|
| 601 |
|
| 602 |
# summarize metrics
|
| 603 |
metrics = {"loss": loss}
|
|
|
|
| 597 |
def eval_step(params, batch):
|
| 598 |
labels = batch.pop("labels")
|
| 599 |
logits, latent_codes = model(**batch, params=params, train=False)[:2]
|
| 600 |
+
loss = loss_fn(logits, labels, latent_codes, rng)
|
| 601 |
|
| 602 |
# summarize metrics
|
| 603 |
metrics = {"loss": loss}
|