Commit ·
5e0ceba
1
Parent(s): 5e05341
Add logging to check whether the dataset is loaded
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
|
@@ -387,6 +387,7 @@ def main():
|
|
| 387 |
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
|
| 388 |
|
| 389 |
if training_args.do_train:
|
|
|
|
| 390 |
raw_datasets["train"] = load_maybe_streaming_dataset(
|
| 391 |
data_args.dataset_name,
|
| 392 |
data_args.dataset_config_name,
|
|
@@ -394,8 +395,10 @@ def main():
|
|
| 394 |
use_auth_token=True if model_args.use_auth_token else None,
|
| 395 |
streaming=data_args.streaming,
|
| 396 |
)
|
|
|
|
| 397 |
|
| 398 |
if training_args.do_eval:
|
|
|
|
| 399 |
raw_datasets["eval"] = load_maybe_streaming_dataset(
|
| 400 |
data_args.dataset_name,
|
| 401 |
data_args.dataset_config_name,
|
|
@@ -403,6 +406,7 @@ def main():
|
|
| 403 |
use_auth_token=True if model_args.use_auth_token else None,
|
| 404 |
streaming=data_args.streaming,
|
| 405 |
)
|
|
|
|
| 406 |
|
| 407 |
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
|
| 408 |
|
|
@@ -602,8 +606,12 @@ def main():
|
|
| 602 |
callbacks=[ShuffleCallback()] if data_args.streaming else None,
|
| 603 |
)
|
| 604 |
|
|
|
|
|
|
|
|
|
|
| 605 |
# 12. Training
|
| 606 |
if training_args.do_train:
|
|
|
|
| 607 |
checkpoint = None
|
| 608 |
if training_args.resume_from_checkpoint is not None:
|
| 609 |
checkpoint = training_args.resume_from_checkpoint
|
|
@@ -651,6 +659,7 @@ def main():
|
|
| 651 |
if model_args.model_index_name is not None:
|
| 652 |
kwargs["model_name"] = model_args.model_index_name
|
| 653 |
|
|
|
|
| 654 |
if training_args.push_to_hub:
|
| 655 |
trainer.push_to_hub(**kwargs)
|
| 656 |
else:
|
|
|
|
| 387 |
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
|
| 388 |
|
| 389 |
if training_args.do_train:
|
| 390 |
+
logger.info("Loading training dataset")
|
| 391 |
raw_datasets["train"] = load_maybe_streaming_dataset(
|
| 392 |
data_args.dataset_name,
|
| 393 |
data_args.dataset_config_name,
|
|
|
|
| 395 |
use_auth_token=True if model_args.use_auth_token else None,
|
| 396 |
streaming=data_args.streaming,
|
| 397 |
)
|
| 398 |
+
logger.info("Loaded training dataset")
|
| 399 |
|
| 400 |
if training_args.do_eval:
|
| 401 |
+
logger.info("Loading evaluation dataset")
|
| 402 |
raw_datasets["eval"] = load_maybe_streaming_dataset(
|
| 403 |
data_args.dataset_name,
|
| 404 |
data_args.dataset_config_name,
|
|
|
|
| 406 |
use_auth_token=True if model_args.use_auth_token else None,
|
| 407 |
streaming=data_args.streaming,
|
| 408 |
)
|
| 409 |
+
logger.info("Loaded evaluation dataset")
|
| 410 |
|
| 411 |
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
|
| 412 |
|
|
|
|
| 606 |
callbacks=[ShuffleCallback()] if data_args.streaming else None,
|
| 607 |
)
|
| 608 |
|
| 609 |
+
push_to_hub = training_args.push_to_hub
|
| 610 |
+
training_args.push_to_hub = False
|
| 611 |
+
|
| 612 |
# 12. Training
|
| 613 |
if training_args.do_train:
|
| 614 |
+
logger.info("*** Train ***")
|
| 615 |
checkpoint = None
|
| 616 |
if training_args.resume_from_checkpoint is not None:
|
| 617 |
checkpoint = training_args.resume_from_checkpoint
|
|
|
|
| 659 |
if model_args.model_index_name is not None:
|
| 660 |
kwargs["model_name"] = model_args.model_index_name
|
| 661 |
|
| 662 |
+
training_args.push_to_hub = push_to_hub
|
| 663 |
if training_args.push_to_hub:
|
| 664 |
trainer.push_to_hub(**kwargs)
|
| 665 |
else:
|