feat: minor improvements
dalle_mini/model/configuration.py  CHANGED

@@ -80,7 +80,6 @@ class DalleBartConfig(PretrainedConfig):
         self.decoder_layerdrop = decoder_layerdrop
         self.classifier_dropout = classifier_dropout
         self.use_cache = use_cache
-        self.num_hidden_layers = encoder_layers
         self.gradient_checkpointing = gradient_checkpointing
         self.scale_embedding = (
             scale_embedding  # scale factor will be sqrt(d_model) if True
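The removed line stored encoder_layers a second time under the generic name num_hidden_layers. As an illustrative sketch only (not what this commit does), a PretrainedConfig subclass can instead expose a standard attribute name as an alias of the field it actually stores via attribute_map; the class and values below are hypothetical:

from transformers import PretrainedConfig

class DemoBartLikeConfig(PretrainedConfig):
    # Alias the generic name onto the field this config actually stores, so
    # config.num_hidden_layers resolves to config.encoder_layers.
    attribute_map = {"num_hidden_layers": "encoder_layers"}

    def __init__(self, encoder_layers=12, **kwargs):
        self.encoder_layers = encoder_layers
        super().__init__(**kwargs)

config = DemoBartLikeConfig(encoder_layers=24)
assert config.num_hidden_layers == 24  # read through the alias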
tools/train/train.py  CHANGED

@@ -375,6 +375,9 @@ def main():
     datasets.utils.logging.set_verbosity_error()
     transformers.utils.logging.set_verbosity_error()
 
+    logger.info(f"TPUs: {jax.device_count()}")
+    assert jax.device_count() == 8, "TPUs in use, please check running processes"
+
     # Set the verbosity to info of the Transformers logger (on main process only):
     logger.info(f"Training/evaluation parameters {training_args}")
 
@@ -443,9 +446,6 @@ def main():
         use_fast=True,
     )
 
-    logger.info(f"TPUs: {jax.device_count()}")
-    assert jax.device_count() == 8, "TPUs in use, please check running processes"
-
     # Preprocessing the datasets.
     # We need to normalize and tokenize inputs and targets.
 
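The two hunks above move the TPU check so it runs before tokenizer and dataset setup rather than after. As a standalone, assumed sketch of inspecting the devices JAX can see (not code from this repo; the hard-coded 8 in the assert corresponds to a single 8-core TPU host):

import jax

# Report visible accelerators before doing any heavy setup.
print(f"total devices: {jax.device_count()}, local devices: {jax.local_device_count()}")
print("platforms:", sorted({d.platform for d in jax.devices()}))

The remaining train.py hunks thread a parameter count through the training log, the run settings recorded on W&B, and the checkpoint artifact metadata: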
@@ -474,6 +474,7 @@
     num_train_steps = (
         steps_per_epoch * num_epochs if steps_per_epoch is not None else None
     )
+    num_params = model.num_params
 
     # Create learning rate schedule
     learning_rate_fn = create_learning_rate_fn(
@@ -602,6 +603,7 @@
     logger.info(
         f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
     )
+    logger.info(f" Model parameters = {num_params:,}")
     epochs = tqdm(
         range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
     )
@@ -616,7 +618,7 @@
             "len_train_dataset": len_train_dataset,
             "len_eval_dataset": len_eval_dataset,
             "batch_size_per_update": batch_size_per_update,
-            "num_params":
+            "num_params": num_params,
         }
     )
 
@@ -693,7 +695,7 @@
         c.cleanup(wandb.util.from_human_size("10GB"))
 
         metadata = dict(state_dict)
-        metadata["num_params"] =
+        metadata["num_params"] = num_params
         if eval_metrics is not None:
             metadata["eval"] = eval_metrics
         artifact = wandb.Artifact(
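The diff reads model.num_params, an attribute the repo's model class exposes; its implementation is not shown on this page. As a rough, assumed equivalent, the same figure can be computed directly from a Flax parameter pytree:

import jax

def count_params(params) -> int:
    # Sum the element counts of every array leaf in a (nested) params dict.
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))

# e.g. count_params(model.params), assuming a Flax model that keeps its weights
# in model.params; that attribute is an assumption here, not shown in the diff.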
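The last two hunks also record that count alongside other run settings sent to W&B and in the metadata of the wandb.Artifact holding the checkpoint. A minimal, self-contained sketch of attaching such values to a W&B run and artifact (assumed usage with made-up names and values, not this repo's code):

import wandb

run = wandb.init(project="demo-project", job_type="train")  # hypothetical project
run.config.update({"batch_size_per_update": 8, "num_params": 123_456_789})  # example values

artifact = wandb.Artifact("model-demo", type="model", metadata={"num_params": 123_456_789})
with artifact.new_file("README.txt") as f:
    f.write("demo checkpoint placeholder")  # give the artifact at least one file
run.log_artifact(artifact)
run.finish()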