gigant commited on
Commit
1e7d7f7
·
1 Parent(s): cb7eca3

added rng key to the eval_step

Browse files
Files changed (1) hide show
  1. train.py +1 -1
train.py CHANGED
@@ -597,7 +597,7 @@ def main():
597
  def eval_step(params, batch):
598
  labels = batch.pop("labels")
599
  logits, latent_codes = model(**batch, params=params, train=False)[:2]
600
- loss = loss_fn(logits, labels, latent_codes)
601
 
602
  # summarize metrics
603
  metrics = {"loss": loss}
 
597
  def eval_step(params, batch):
598
  labels = batch.pop("labels")
599
  logits, latent_codes = model(**batch, params=params, train=False)[:2]
600
+ loss = loss_fn(logits, labels, latent_codes, rng)
601
 
602
  # summarize metrics
603
  metrics = {"loss": loss}