test
Browse files
run_flax_speech_recognition_seq2seq_streaming.py
CHANGED
|
@@ -66,7 +66,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
|
| 66 |
from transformers.utils.versions import require_version
|
| 67 |
|
| 68 |
import pandas as pd
|
| 69 |
-
|
| 70 |
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
| 71 |
check_min_version("4.27.0.dev0")
|
| 72 |
|
|
@@ -417,8 +417,8 @@ def collate_batch(samples):
|
|
| 417 |
def data_loader(
|
| 418 |
dataset: Dataset,
|
| 419 |
batch_size: int,
|
| 420 |
-
drop_last: bool=True,
|
| 421 |
-
num_workers: int=0,
|
| 422 |
) -> Generator:
|
| 423 |
"""
|
| 424 |
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
|
|
@@ -455,7 +455,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
|
| 455 |
|
| 456 |
|
| 457 |
def create_learning_rate_fn(
|
| 458 |
-
num_train_steps: int, num_warmup_steps: int, learning_rate: float, warmup_init_value: float=0.0, decay_end_value: float=0.0,
|
| 459 |
) -> Callable[[int], jnp.array]:
|
| 460 |
"""Returns a linear warmup, linear_decay learning rate function."""
|
| 461 |
warmup_fn = optax.linear_schedule(
|
|
@@ -525,12 +525,15 @@ def main():
|
|
| 525 |
if training_args.push_to_hub:
|
| 526 |
if training_args.hub_model_id is None:
|
| 527 |
repo_name = get_full_repo_name(
|
| 528 |
-
Path(training_args.output_dir).absolute(
|
| 529 |
-
|
|
|
|
| 530 |
)
|
| 531 |
else:
|
| 532 |
repo_name = training_args.hub_model_id
|
| 533 |
-
create_repo(
|
|
|
|
|
|
|
| 534 |
repo = Repository(training_args.output_dir,
|
| 535 |
clone_from=repo_name, token=training_args.hub_token)
|
| 536 |
|
|
@@ -710,72 +713,86 @@ def main():
|
|
| 710 |
label_str = [label_str[i]
|
| 711 |
for i in range(len(label_str)) if len(label_str[i]) > 0]
|
| 712 |
|
| 713 |
-
wer = 100 *
|
| 714 |
-
|
|
|
|
|
|
|
| 715 |
|
| 716 |
return {"wer": wer, "cer": cer}
|
| 717 |
|
| 718 |
# TODO: Sanchit. This part is not essential and a bit verbose.
|
| 719 |
# We do however think it is extremely useful info for debugging and improving the model
|
| 720 |
-
|
| 721 |
-
def write_predictions(step,eval_samples,eval_metrics, pred_ids, label_ids):
|
| 722 |
-
predictions_folder_name = os.path.join(
|
|
|
|
| 723 |
eval_table = f"| STEP| loss | wer |cer|\n| ---| --- | --- |--- |\n| **{step}**| {eval_metrics['loss']:.3f} | {eval_metrics['wer']:.3f} |{eval_metrics['cer']:.3f} |"
|
| 724 |
-
|
| 725 |
# Put all predictions into a table
|
| 726 |
-
inference_df = pd.DataFrame(columns=['mp3','target','prediction'])
|
| 727 |
-
|
| 728 |
idx = 0
|
| 729 |
-
for pred,label in zip(pred_ids,label_ids):
|
| 730 |
-
pred_text = tokenizer.decode(pred,skip_special_tokens=True)
|
| 731 |
-
label_text = tokenizer.decode(label,skip_special_tokens=True)
|
| 732 |
-
formatted_pred_text = " ".join(
|
|
|
|
| 733 |
#wer = 100 * metric_wer.compute(predictions=[pred_text], references=[label_text])
|
| 734 |
#wer_formatted = f'{wer:.2f}'
|
| 735 |
audio_control = f'[↓]({"mp3/pred_"+str(idx)+".mp3"})'
|
| 736 |
-
new_row = pd.DataFrame(
|
| 737 |
-
|
|
|
|
|
|
|
| 738 |
idx += 1
|
| 739 |
-
|
| 740 |
# Create the prediction table of the first N rows
|
| 741 |
-
inference_df = inference_df[['mp3', 'target','prediction']]
|
| 742 |
-
predict_table = inference_df[0:data_args.number_write_predictions].to_markdown(
|
| 743 |
-
|
|
|
|
| 744 |
# Build the markdown page
|
| 745 |
markdown_str = f"{eval_table}\n\n{predict_table}"
|
| 746 |
-
|
| 747 |
# Save the stats file
|
| 748 |
stats_file_name = f"{predictions_folder_name}/step_{step}.md"
|
| 749 |
with open(stats_file_name, "w") as f:
|
| 750 |
f.write(markdown_str)
|
| 751 |
-
|
| 752 |
# Create an header for all the files
|
| 753 |
-
md_files = sorted(os.path.basename(file) for file in os.listdir(
|
| 754 |
-
|
| 755 |
-
|
|
|
|
|
|
|
|
|
|
| 756 |
|
| 757 |
# Add this header to all the stats file in the folder
|
| 758 |
for filename in os.listdir(predictions_folder_name):
|
| 759 |
if filename.startswith("step_"):
|
| 760 |
with open(os.path.join(predictions_folder_name, filename), "r+") as f:
|
| 761 |
content = f.read()
|
| 762 |
-
new_content = md_header + "\n\n" +
|
|
|
|
| 763 |
f.seek(0)
|
| 764 |
f.write(new_content)
|
| 765 |
f.truncate()
|
| 766 |
-
|
| 767 |
# Add a folder for the mp3 files
|
| 768 |
-
if not os.path.exists(os.path.join(predictions_folder_name,"mp3")):
|
| 769 |
-
os.makedirs(os.path.join(predictions_folder_name,"mp3"))
|
| 770 |
-
|
| 771 |
for idx, mp3array in enumerate(eval_samples[0:data_args.number_write_predictions]):
|
| 772 |
-
audio_segment = AudioSegment(data=mp3array.tobytes(
|
| 773 |
-
|
|
|
|
|
|
|
| 774 |
if not os.path.exists(file_path):
|
| 775 |
audio_segment.export(file_path, format="mp3")
|
| 776 |
-
|
| 777 |
-
logger.info(
|
| 778 |
-
|
|
|
|
| 779 |
# 9. Save feature extractor, tokenizer and config
|
| 780 |
feature_extractor.save_pretrained(training_args.output_dir)
|
| 781 |
tokenizer.save_pretrained(training_args.output_dir)
|
|
@@ -800,7 +817,7 @@ def main():
|
|
| 800 |
from flax.metrics.tensorboard import SummaryWriter
|
| 801 |
|
| 802 |
summary_writer = SummaryWriter(
|
| 803 |
-
log_dir=Path(os.path.join(training_args.output_dir,"events")))
|
| 804 |
except ImportError as ie:
|
| 805 |
has_tensorboard = False
|
| 806 |
logger.warning(
|
|
@@ -963,11 +980,12 @@ def main():
|
|
| 963 |
|
| 964 |
# Reset the predictions folder when `number_write_predictions` is set
|
| 965 |
if data_args.number_write_predictions:
|
| 966 |
-
predictions_folder_name = os.path.join(
|
|
|
|
| 967 |
shutil.rmtree(predictions_folder_name, ignore_errors=True)
|
| 968 |
os.makedirs(predictions_folder_name, exist_ok=True)
|
| 969 |
logger.info(f"Created folder {predictions_folder_name}")
|
| 970 |
-
|
| 971 |
# Create parallel version of the train and eval step
|
| 972 |
p_train_step = jax.pmap(
|
| 973 |
partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0, )
|
|
@@ -995,7 +1013,8 @@ def main():
|
|
| 995 |
|
| 996 |
train_metrics = []
|
| 997 |
epoch = 0
|
| 998 |
-
train_dataset = vectorized_datasets["train"].shuffle(
|
|
|
|
| 999 |
eval_dataset = vectorized_datasets["eval"]
|
| 1000 |
train_loader = data_loader(train_dataset, train_batch_size)
|
| 1001 |
# train
|
|
@@ -1025,9 +1044,11 @@ def main():
|
|
| 1025 |
eval_preds = []
|
| 1026 |
eval_labels = []
|
| 1027 |
eval_samples = []
|
| 1028 |
-
eval_loader = data_loader(
|
|
|
|
| 1029 |
if data_args.max_eval_samples:
|
| 1030 |
-
max_eval_steps_iter = range(
|
|
|
|
| 1031 |
else:
|
| 1032 |
max_eval_steps_iter = itertools.repeat(None)
|
| 1033 |
for _ in tqdm(max_eval_steps_iter, desc="Evaluating...", position=2, leave=False):
|
|
@@ -1043,9 +1064,9 @@ def main():
|
|
| 1043 |
state.params, batch.data, min_device_batch=training_args.per_device_eval_batch_size
|
| 1044 |
)
|
| 1045 |
eval_metrics.append(metrics)
|
| 1046 |
-
if training_args.predict_with_generate and data_args.number_write_predictions and len(eval_samples)<data_args.number_write_predictions+eval_batch_size:
|
| 1047 |
eval_samples.extend(samples['input_features'])
|
| 1048 |
-
|
| 1049 |
# generation
|
| 1050 |
if training_args.predict_with_generate:
|
| 1051 |
generated_ids = pad_shard_unpad(
|
|
@@ -1074,9 +1095,10 @@ def main():
|
|
| 1074 |
if has_tensorboard and jax.process_index() == 0:
|
| 1075 |
write_metric(summary_writer, train_metrics,
|
| 1076 |
eval_metrics, train_time, step)
|
| 1077 |
-
|
| 1078 |
if training_args.predict_with_generate and data_args.number_write_predictions:
|
| 1079 |
-
write_predictions(step,eval_samples,
|
|
|
|
| 1080 |
|
| 1081 |
# save checkpoint after each epoch and push checkpoint to the hub
|
| 1082 |
if jax.process_index() == 0:
|
|
|
|
| 66 |
from transformers.utils.versions import require_version
|
| 67 |
|
| 68 |
import pandas as pd
|
| 69 |
+
|
| 70 |
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
| 71 |
check_min_version("4.27.0.dev0")
|
| 72 |
|
|
|
|
| 417 |
def data_loader(
|
| 418 |
dataset: Dataset,
|
| 419 |
batch_size: int,
|
| 420 |
+
drop_last: bool = True,
|
| 421 |
+
num_workers: int = 0,
|
| 422 |
) -> Generator:
|
| 423 |
"""
|
| 424 |
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
|
|
|
|
| 455 |
|
| 456 |
|
| 457 |
def create_learning_rate_fn(
|
| 458 |
+
num_train_steps: int, num_warmup_steps: int, learning_rate: float, warmup_init_value: float = 0.0, decay_end_value: float = 0.0,
|
| 459 |
) -> Callable[[int], jnp.array]:
|
| 460 |
"""Returns a linear warmup, linear_decay learning rate function."""
|
| 461 |
warmup_fn = optax.linear_schedule(
|
|
|
|
| 525 |
if training_args.push_to_hub:
|
| 526 |
if training_args.hub_model_id is None:
|
| 527 |
repo_name = get_full_repo_name(
|
| 528 |
+
Path(training_args.output_dir).absolute().name,
|
| 529 |
+
token=training_args.hub_token,
|
| 530 |
+
organization=training_args.push_to_hub_organization,
|
| 531 |
)
|
| 532 |
else:
|
| 533 |
repo_name = training_args.hub_model_id
|
| 534 |
+
create_repo(
|
| 535 |
+
repo_name, exist_ok=True, token=training_args.hub_token, private=training_args.hub_private_repo
|
| 536 |
+
)
|
| 537 |
repo = Repository(training_args.output_dir,
|
| 538 |
clone_from=repo_name, token=training_args.hub_token)
|
| 539 |
|
|
|
|
| 713 |
label_str = [label_str[i]
|
| 714 |
for i in range(len(label_str)) if len(label_str[i]) > 0]
|
| 715 |
|
| 716 |
+
wer = 100 * \
|
| 717 |
+
metric_wer.compute(predictions=pred_str, references=label_str)
|
| 718 |
+
cer = 100 * \
|
| 719 |
+
metric_cer.compute(predictions=pred_str, references=label_str)
|
| 720 |
|
| 721 |
return {"wer": wer, "cer": cer}
|
| 722 |
|
| 723 |
# TODO: Sanchit. This part is not essential and a bit verbose.
|
| 724 |
# We do however think it is extremely useful info for debugging and improving the model
|
| 725 |
+
|
| 726 |
+
def write_predictions(step, eval_samples, eval_metrics, pred_ids, label_ids):
    """Write a markdown report and mp3 clips for one evaluation step.

    Creates ``<output_dir>/predictions/step_<step>.md`` containing a metrics
    table followed by a table of the first
    ``data_args.number_write_predictions`` predictions, rewrites the navigation
    header of every ``step_*.md`` file in that folder, and exports the first
    ``data_args.number_write_predictions`` entries of ``eval_samples`` as mp3
    files under ``<output_dir>/predictions/mp3``.

    Args:
        step: Training step; used in the report file name and metrics table.
        eval_samples: Sliceable sequence of arrays exposing ``tobytes()`` and
            ``dtype.itemsize`` (presumably numpy arrays -- TODO confirm).
        eval_metrics: Mapping providing numeric ``'loss'``, ``'wer'`` and
            ``'cer'`` entries.
        pred_ids: Iterable of predicted token id sequences.
        label_ids: Iterable of reference token id sequences, paired with
            ``pred_ids`` by position.

    Raises:
        ValueError: If an existing ``step_*.md`` file does not contain the
            metrics table header, or its name has no integer step number.
    """
    predictions_folder_name = os.path.join(
        training_args.output_dir, "predictions")
    eval_table = f"| STEP| loss | wer |cer|\n| ---| --- | --- |--- |\n| **{step}**| {eval_metrics['loss']:.3f} | {eval_metrics['wer']:.3f} |{eval_metrics['cer']:.3f} |"

    # Collect the rows first and build the DataFrame once; concatenating a
    # one-row frame per prediction copies the whole table on every iteration.
    rows = []
    for idx, (pred, label) in enumerate(zip(pred_ids, label_ids)):
        pred_text = tokenizer.decode(pred, skip_special_tokens=True)
        label_text = tokenizer.decode(label, skip_special_tokens=True)
        # Bold the words found in the target, italicise the others.
        # NOTE(review): `word in label_text` is a substring test, so a short
        # word can match inside a longer target word -- confirm whether
        # whole-word matching is intended.
        formatted_pred_text = " ".join(
            f"**{word}**" if word in label_text else f"*{word}*"
            for word in pred_text.split()
        )
        rows.append({
            "mp3": f"[↓](mp3/pred_{idx}.mp3)",
            "target": label_text,
            "prediction": formatted_pred_text,
        })

    # Create the prediction table of the first N rows
    inference_df = pd.DataFrame(rows, columns=["mp3", "target", "prediction"])
    predict_table = inference_df[0:data_args.number_write_predictions].to_markdown(
        index=False)

    # Build the markdown page
    markdown_str = f"{eval_table}\n\n{predict_table}"

    # Save the stats file. The table contains a non-ASCII arrow, so the
    # encoding is pinned instead of relying on the locale default.
    stats_file_name = f"{predictions_folder_name}/step_{step}.md"
    with open(stats_file_name, "w", encoding="utf-8") as f:
        f.write(markdown_str)

    # Create a header linking to every step file, ordered by step number
    # (ties broken by file name).
    step_files = sorted(
        (name for name in os.listdir(predictions_folder_name)
         if name.startswith("step_")),
        key=lambda name: (int(name[:-3].split("_")[1]), name),
    )
    md_header = " | ".join(
        f"[Step {name[:-3].split('_')[1]}]({name})" for name in step_files)

    # Add this header to all the stats files in the folder, replacing any
    # header written by an earlier call (everything before the metrics table).
    for filename in step_files:
        with open(os.path.join(predictions_folder_name, filename), "r+", encoding="utf-8") as f:
            content = f.read()
            new_content = md_header + "\n\n" + \
                content[content.index("| STEP| loss | wer"):]
            f.seek(0)
            f.write(new_content)
            f.truncate()

    # Add a folder for the mp3 files
    mp3_folder_name = os.path.join(predictions_folder_name, "mp3")
    os.makedirs(mp3_folder_name, exist_ok=True)

    # NOTE(review): the frame rate is hard-coded to 16000 and the samples are
    # exported as raw mono audio -- confirm the caller passes waveform data.
    for idx, mp3array in enumerate(eval_samples[0:data_args.number_write_predictions]):
        file_path = os.path.join(mp3_folder_name, f"pred_{idx}.mp3")
        # Clips written at an earlier step are kept as they are.
        if not os.path.exists(file_path):
            audio_segment = AudioSegment(
                data=mp3array.tobytes(),
                sample_width=mp3array.dtype.itemsize,
                frame_rate=16000,
                channels=1,
            )
            audio_segment.export(file_path, format="mp3")

    logger.info(
        f"Created {stats_file_name} and updated the headers of the other stats files")
|
| 795 |
+
|
| 796 |
# 9. Save feature extractor, tokenizer and config
|
| 797 |
feature_extractor.save_pretrained(training_args.output_dir)
|
| 798 |
tokenizer.save_pretrained(training_args.output_dir)
|
|
|
|
| 817 |
from flax.metrics.tensorboard import SummaryWriter
|
| 818 |
|
| 819 |
summary_writer = SummaryWriter(
|
| 820 |
+
log_dir=Path(os.path.join(training_args.output_dir, "events")))
|
| 821 |
except ImportError as ie:
|
| 822 |
has_tensorboard = False
|
| 823 |
logger.warning(
|
|
|
|
| 980 |
|
| 981 |
# Reset the predictions folder when `number_write_predictions` is set
|
| 982 |
if data_args.number_write_predictions:
|
| 983 |
+
predictions_folder_name = os.path.join(
|
| 984 |
+
training_args.output_dir, "predictions")
|
| 985 |
shutil.rmtree(predictions_folder_name, ignore_errors=True)
|
| 986 |
os.makedirs(predictions_folder_name, exist_ok=True)
|
| 987 |
logger.info(f"Created folder {predictions_folder_name}")
|
| 988 |
+
|
| 989 |
# Create parallel version of the train and eval step
|
| 990 |
p_train_step = jax.pmap(
|
| 991 |
partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0, )
|
|
|
|
| 1013 |
|
| 1014 |
train_metrics = []
|
| 1015 |
epoch = 0
|
| 1016 |
+
train_dataset = vectorized_datasets["train"].shuffle(
|
| 1017 |
+
seed=training_args.seed, buffer_size=data_args.shuffle_buffer_size)
|
| 1018 |
eval_dataset = vectorized_datasets["eval"]
|
| 1019 |
train_loader = data_loader(train_dataset, train_batch_size)
|
| 1020 |
# train
|
|
|
|
| 1044 |
eval_preds = []
|
| 1045 |
eval_labels = []
|
| 1046 |
eval_samples = []
|
| 1047 |
+
eval_loader = data_loader(
|
| 1048 |
+
eval_dataset, eval_batch_size, drop_last=False)
|
| 1049 |
if data_args.max_eval_samples:
|
| 1050 |
+
max_eval_steps_iter = range(
|
| 1051 |
+
1 + data_args.max_eval_samples // eval_batch_size)
|
| 1052 |
else:
|
| 1053 |
max_eval_steps_iter = itertools.repeat(None)
|
| 1054 |
for _ in tqdm(max_eval_steps_iter, desc="Evaluating...", position=2, leave=False):
|
|
|
|
| 1064 |
state.params, batch.data, min_device_batch=training_args.per_device_eval_batch_size
|
| 1065 |
)
|
| 1066 |
eval_metrics.append(metrics)
|
| 1067 |
+
if training_args.predict_with_generate and data_args.number_write_predictions and len(eval_samples) < data_args.number_write_predictions+eval_batch_size:
|
| 1068 |
eval_samples.extend(samples['input_features'])
|
| 1069 |
+
|
| 1070 |
# generation
|
| 1071 |
if training_args.predict_with_generate:
|
| 1072 |
generated_ids = pad_shard_unpad(
|
|
|
|
| 1095 |
if has_tensorboard and jax.process_index() == 0:
|
| 1096 |
write_metric(summary_writer, train_metrics,
|
| 1097 |
eval_metrics, train_time, step)
|
| 1098 |
+
|
| 1099 |
if training_args.predict_with_generate and data_args.number_write_predictions:
|
| 1100 |
+
write_predictions(step, eval_samples,
|
| 1101 |
+
eval_metrics, eval_preds, eval_labels)
|
| 1102 |
|
| 1103 |
# save checkpoint after each epoch and push checkpoint to the hub
|
| 1104 |
if jax.process_index() == 0:
|