Vivek commited on
Commit
f8dd798
·
1 Parent(s): 4892952

updated 50 train files

Browse files
Files changed (1) hide show
  1. src/train.py +1 -1
src/train.py CHANGED
@@ -206,7 +206,7 @@ def main():
206
  if idx%5==0:
207
  summary_writer.scalar('train_loss',flax.jax_utils.unreplicate(train_metric)['loss'].item(),idx)
208
  summary_writer.scalar('train_accuracy', jax.device_get(train_metric['accuracy']).mean().item(),idx)
209
- if idx%20==0:
210
  logger.info(f"train_step_loss{idx}: {flax.jax_utils.unreplicate(train_metric)['loss'].item()} train_step_acc{idx}: {jax.device_get(train_metric['accuracy']).mean().item()}")
211
 
212
  progress_bar_train.update(1)
 
206
  if idx%5==0:
207
  summary_writer.scalar('train_loss',flax.jax_utils.unreplicate(train_metric)['loss'].item(),idx)
208
  summary_writer.scalar('train_accuracy', jax.device_get(train_metric['accuracy']).mean().item(),idx)
209
+ if idx%50==0:
210
  logger.info(f"train_step_loss{idx}: {flax.jax_utils.unreplicate(train_metric)['loss'].item()} train_step_acc{idx}: {jax.device_get(train_metric['accuracy']).mean().item()}")
211
 
212
  progress_bar_train.update(1)