Commit ·
eeecd97
1
Parent(s): 5e0ceba
add logs
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
|
@@ -460,6 +460,7 @@ def main():
|
|
| 460 |
revision=model_args.model_revision,
|
| 461 |
use_auth_token=True if model_args.use_auth_token else None,
|
| 462 |
)
|
|
|
|
| 463 |
|
| 464 |
if model.config.decoder_start_token_id is None:
|
| 465 |
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
|
@@ -498,6 +499,7 @@ def main():
|
|
| 498 |
if data_args.streaming
|
| 499 |
else raw_datasets["train"].select(range(data_args.max_train_samples))
|
| 500 |
)
|
|
|
|
| 501 |
|
| 502 |
if data_args.max_eval_samples is not None:
|
| 503 |
raw_datasets["eval"] = (
|
|
@@ -505,6 +507,7 @@ def main():
|
|
| 505 |
if data_args.streaming
|
| 506 |
else raw_datasets["eval"].select(range(data_args.max_eval_samples))
|
| 507 |
)
|
|
|
|
| 508 |
|
| 509 |
def prepare_dataset(batch):
|
| 510 |
# process audio
|
|
@@ -526,6 +529,7 @@ def main():
|
|
| 526 |
prepare_dataset,
|
| 527 |
remove_columns=raw_datasets_features,
|
| 528 |
).with_format("torch")
|
|
|
|
| 529 |
|
| 530 |
if training_args.do_train and data_args.streaming:
|
| 531 |
# manually shuffle if streaming (done by the trainer for non-streaming)
|
|
@@ -533,6 +537,7 @@ def main():
|
|
| 533 |
buffer_size=data_args.shuffle_buffer_size,
|
| 534 |
seed=training_args.seed,
|
| 535 |
)
|
|
|
|
| 536 |
|
| 537 |
# filter training data that is shorter than min_input_length or longer than
|
| 538 |
# max_input_length
|
|
@@ -544,10 +549,12 @@ def main():
|
|
| 544 |
is_audio_in_length_range,
|
| 545 |
input_columns=["input_length"],
|
| 546 |
)
|
|
|
|
| 547 |
|
| 548 |
# 8. Load Metric
|
| 549 |
metric = evaluate.load("wer")
|
| 550 |
do_normalize_eval = data_args.do_normalize_eval
|
|
|
|
| 551 |
|
| 552 |
def compute_metrics(pred):
|
| 553 |
pred_ids = pred.predictions
|
|
@@ -577,12 +584,13 @@ def main():
|
|
| 577 |
config.save_pretrained(training_args.output_dir)
|
| 578 |
|
| 579 |
processor = AutoProcessor.from_pretrained(training_args.output_dir)
|
| 580 |
-
|
| 581 |
# 10. Define data collator
|
| 582 |
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
|
| 583 |
processor=processor,
|
| 584 |
decoder_start_token_id=model.config.decoder_start_token_id,
|
| 585 |
)
|
|
|
|
| 586 |
|
| 587 |
# 11. Configure Trainer
|
| 588 |
# Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
|
|
@@ -594,6 +602,9 @@ def main():
|
|
| 594 |
elif isinstance(train_dataloader.dataset, IterableDataset):
|
| 595 |
train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
|
| 596 |
|
|
|
|
|
|
|
|
|
|
| 597 |
# Initialize Trainer
|
| 598 |
trainer = Seq2SeqTrainer(
|
| 599 |
model=model,
|
|
@@ -605,9 +616,7 @@ def main():
|
|
| 605 |
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
| 606 |
callbacks=[ShuffleCallback()] if data_args.streaming else None,
|
| 607 |
)
|
| 608 |
-
|
| 609 |
-
push_to_hub = training_args.push_to_hub
|
| 610 |
-
training_args.push_to_hub = False
|
| 611 |
|
| 612 |
# 12. Training
|
| 613 |
if training_args.do_train:
|
|
@@ -643,6 +652,7 @@ def main():
|
|
| 643 |
trainer.save_metrics("eval", metrics)
|
| 644 |
|
| 645 |
# 14. Write Training Stats
|
|
|
|
| 646 |
kwargs = {
|
| 647 |
"finetuned_from": model_args.model_name_or_path,
|
| 648 |
"tasks": "automatic-speech-recognition",
|
|
@@ -659,11 +669,13 @@ def main():
|
|
| 659 |
if model_args.model_index_name is not None:
|
| 660 |
kwargs["model_name"] = model_args.model_index_name
|
| 661 |
|
| 662 |
-
|
|
|
|
| 663 |
if training_args.push_to_hub:
|
| 664 |
trainer.push_to_hub(**kwargs)
|
| 665 |
else:
|
| 666 |
trainer.create_model_card(**kwargs)
|
|
|
|
| 667 |
|
| 668 |
return results
|
| 669 |
|
|
|
|
| 460 |
revision=model_args.model_revision,
|
| 461 |
use_auth_token=True if model_args.use_auth_token else None,
|
| 462 |
)
|
| 463 |
+
logger.info("Loaded pretrained model, tokenizer, and feature extractor")
|
| 464 |
|
| 465 |
if model.config.decoder_start_token_id is None:
|
| 466 |
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
|
|
|
| 499 |
if data_args.streaming
|
| 500 |
else raw_datasets["train"].select(range(data_args.max_train_samples))
|
| 501 |
)
|
| 502 |
+
logger.info("Using %d train samples", data_args.max_train_samples)
|
| 503 |
|
| 504 |
if data_args.max_eval_samples is not None:
|
| 505 |
raw_datasets["eval"] = (
|
|
|
|
| 507 |
if data_args.streaming
|
| 508 |
else raw_datasets["eval"].select(range(data_args.max_eval_samples))
|
| 509 |
)
|
| 510 |
+
logger.info("Using %d eval samples", data_args.max_eval_samples)
|
| 511 |
|
| 512 |
def prepare_dataset(batch):
|
| 513 |
# process audio
|
|
|
|
| 529 |
prepare_dataset,
|
| 530 |
remove_columns=raw_datasets_features,
|
| 531 |
).with_format("torch")
|
| 532 |
+
logger.info("Dataset map pre-processing done")
|
| 533 |
|
| 534 |
if training_args.do_train and data_args.streaming:
|
| 535 |
# manually shuffle if streaming (done by the trainer for non-streaming)
|
|
|
|
| 537 |
buffer_size=data_args.shuffle_buffer_size,
|
| 538 |
seed=training_args.seed,
|
| 539 |
)
|
| 540 |
+
logger.info("Shuffled dataset")
|
| 541 |
|
| 542 |
# filter training data that is shorter than min_input_length or longer than
|
| 543 |
# max_input_length
|
|
|
|
| 549 |
is_audio_in_length_range,
|
| 550 |
input_columns=["input_length"],
|
| 551 |
)
|
| 552 |
+
logger.info("Filtered training dataset")
|
| 553 |
|
| 554 |
# 8. Load Metric
|
| 555 |
metric = evaluate.load("wer")
|
| 556 |
do_normalize_eval = data_args.do_normalize_eval
|
| 557 |
+
logger.info("Loaded metric")
|
| 558 |
|
| 559 |
def compute_metrics(pred):
|
| 560 |
pred_ids = pred.predictions
|
|
|
|
| 584 |
config.save_pretrained(training_args.output_dir)
|
| 585 |
|
| 586 |
processor = AutoProcessor.from_pretrained(training_args.output_dir)
|
| 587 |
+
|
| 588 |
# 10. Define data collator
|
| 589 |
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
|
| 590 |
processor=processor,
|
| 591 |
decoder_start_token_id=model.config.decoder_start_token_id,
|
| 592 |
)
|
| 593 |
+
logger.info("Defined data collator")
|
| 594 |
|
| 595 |
# 11. Configure Trainer
|
| 596 |
# Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
|
|
|
|
| 602 |
elif isinstance(train_dataloader.dataset, IterableDataset):
|
| 603 |
train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
|
| 604 |
|
| 605 |
+
push_to_hub = training_args.push_to_hub
|
| 606 |
+
training_args.push_to_hub = False
|
| 607 |
+
|
| 608 |
# Initialize Trainer
|
| 609 |
trainer = Seq2SeqTrainer(
|
| 610 |
model=model,
|
|
|
|
| 616 |
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
| 617 |
callbacks=[ShuffleCallback()] if data_args.streaming else None,
|
| 618 |
)
|
| 619 |
+
logger.info("Initialized Trainer")
|
|
|
|
|
|
|
| 620 |
|
| 621 |
# 12. Training
|
| 622 |
if training_args.do_train:
|
|
|
|
| 652 |
trainer.save_metrics("eval", metrics)
|
| 653 |
|
| 654 |
# 14. Write Training Stats
|
| 655 |
+
logger.info("Training completed. Writing training stats")
|
| 656 |
kwargs = {
|
| 657 |
"finetuned_from": model_args.model_name_or_path,
|
| 658 |
"tasks": "automatic-speech-recognition",
|
|
|
|
| 669 |
if model_args.model_index_name is not None:
|
| 670 |
kwargs["model_name"] = model_args.model_index_name
|
| 671 |
|
| 672 |
+
logger.info("Pushing model to the hub") if push_to_hub else logger.info("Not pushing model to the hub - creating model card only")
|
| 673 |
+
trainer.args.push_to_hub = push_to_hub
|
| 674 |
if training_args.push_to_hub:
|
| 675 |
trainer.push_to_hub(**kwargs)
|
| 676 |
else:
|
| 677 |
trainer.create_model_card(**kwargs)
|
| 678 |
+
logger.info("*** DONE! ***")
|
| 679 |
|
| 680 |
return results
|
| 681 |
|