Spaces:
Runtime error
Runtime error
feat: log num_parameters early
Browse files- tools/train/train.py +32 -31
tools/train/train.py
CHANGED

@@ -558,6 +558,35 @@ def main():
     )
     num_params = model.num_params
 
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len_train_dataset}")
+    logger.info(f"  Num Epochs = {num_epochs}")
+    logger.info(
+        f"  Batch size per device = {training_args.per_device_train_batch_size}"
+    )
+    logger.info(f"  Number of devices = {jax.device_count()}")
+    logger.info(
+        f"  Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
+    )
+    logger.info(f"  Batch size per update = {batch_size_per_step}")
+    logger.info(f"  Model parameters = {num_params:,}")
+
+    # create wandb run
+    if jax.process_index() == 0:
+        # set default x-axis as 'train/step'
+        wandb.define_metric("*", step_metric="train/step")
+
+        # add interesting config parameters
+        wandb.config.update(
+            {
+                "len_train_dataset": len_train_dataset,
+                "len_eval_dataset": len_eval_dataset,
+                "batch_size_per_step": batch_size_per_step,
+                "num_params": num_params,
+                "num_devices": jax.device_count(),
+            }
+        )
+
     # Create learning rate schedule
     def create_learning_rate_fn() -> Callable[[int], jnp.array]:
         """Create the learning rate function."""
@@ -915,42 +944,14 @@ def main():
         out_axis_resources=None,
     )
 
-    logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len_train_dataset}")
-    logger.info(f"  Num Epochs = {num_epochs}")
-    logger.info(
-        f"  Batch size per device = {training_args.per_device_train_batch_size}"
-    )
-    logger.info(f"  Number of devices = {jax.device_count()}")
-    logger.info(
-        f"  Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
-    )
-    logger.info(f"  Batch size per update = {batch_size_per_step}")
-    logger.info(f"  Model parameters = {num_params:,}")
-    epochs = tqdm(
-        range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
-    )
-
     # init variables
     last_time = time.perf_counter()
     train_metrics = None
     step = int(state.step)
     metrics_logger = MetricsLogger(step)
-
-
-
-    wandb.define_metric("*", step_metric="train/step")
-
-    # add interesting config parameters
-    wandb.config.update(
-        {
-            "len_train_dataset": len_train_dataset,
-            "len_eval_dataset": len_eval_dataset,
-            "batch_size_per_step": batch_size_per_step,
-            "num_params": num_params,
-            "num_devices": jax.device_count(),
-        }
-    )
+    epochs = tqdm(
+        range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
+    )
 
     def run_evaluation():
         # ======================== Evaluating ==============================