updated 50 train files
Browse files- src/train.py +1 -1
src/train.py
CHANGED
|
@@ -206,7 +206,7 @@ def main():
|
|
| 206 |
if idx%5==0:
|
| 207 |
summary_writer.scalar('train_loss',flax.jax_utils.unreplicate(train_metric)['loss'].item(),idx)
|
| 208 |
summary_writer.scalar('train_accuracy', jax.device_get(train_metric['accuracy']).mean().item(),idx)
|
| 209 |
-
if idx%
|
| 210 |
logger.info(f"train_step_loss{idx}: {flax.jax_utils.unreplicate(train_metric)['loss'].item()} train_step_acc{idx}: {jax.device_get(train_metric['accuracy']).mean().item()}")
|
| 211 |
|
| 212 |
progress_bar_train.update(1)
|
|
|
|
| 206 |
if idx%5==0:
|
| 207 |
summary_writer.scalar('train_loss',flax.jax_utils.unreplicate(train_metric)['loss'].item(),idx)
|
| 208 |
summary_writer.scalar('train_accuracy', jax.device_get(train_metric['accuracy']).mean().item(),idx)
|
| 209 |
+
if idx%50==0:
|
| 210 |
logger.info(f"train_step_loss{idx}: {flax.jax_utils.unreplicate(train_metric)['loss'].item()} train_step_acc{idx}: {jax.device_get(train_metric['accuracy']).mean().item()}")
|
| 211 |
|
| 212 |
progress_bar_train.update(1)
|