pere committed on
Commit
ae49c39
·
1 Parent(s): 92a6382
run_flax_speech_recognition_seq2seq_streaming.py CHANGED
@@ -66,7 +66,7 @@ from transformers.utils import check_min_version, send_example_telemetry
66
  from transformers.utils.versions import require_version
67
 
68
  import pandas as pd
69
-
70
  # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
71
  check_min_version("4.27.0.dev0")
72
 
@@ -417,8 +417,8 @@ def collate_batch(samples):
417
  def data_loader(
418
  dataset: Dataset,
419
  batch_size: int,
420
- drop_last: bool=True,
421
- num_workers: int=0,
422
  ) -> Generator:
423
  """
424
  Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
@@ -455,7 +455,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
455
 
456
 
457
  def create_learning_rate_fn(
458
- num_train_steps: int, num_warmup_steps: int, learning_rate: float, warmup_init_value: float=0.0, decay_end_value: float=0.0,
459
  ) -> Callable[[int], jnp.array]:
460
  """Returns a linear warmup, linear_decay learning rate function."""
461
  warmup_fn = optax.linear_schedule(
@@ -525,12 +525,15 @@ def main():
525
  if training_args.push_to_hub:
526
  if training_args.hub_model_id is None:
527
  repo_name = get_full_repo_name(
528
- Path(training_args.output_dir).absolute(
529
- ).name, token=training_args.hub_token
 
530
  )
531
  else:
532
  repo_name = training_args.hub_model_id
533
- create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
 
 
534
  repo = Repository(training_args.output_dir,
535
  clone_from=repo_name, token=training_args.hub_token)
536
 
@@ -710,72 +713,86 @@ def main():
710
  label_str = [label_str[i]
711
  for i in range(len(label_str)) if len(label_str[i]) > 0]
712
 
713
- wer = 100 * metric_wer.compute(predictions=pred_str, references=label_str)
714
- cer = 100 * metric_cer.compute(predictions=pred_str, references=label_str)
 
 
715
 
716
  return {"wer": wer, "cer": cer}
717
 
718
  # TODO: Sanchit. This part is not essential and a bit verbose.
719
  # We do however think it is extremely useful info for debugging and improving the model
720
-
721
- def write_predictions(step,eval_samples,eval_metrics, pred_ids, label_ids):
722
- predictions_folder_name = os.path.join(training_args.output_dir,"predictions")
 
723
  eval_table = f"| STEP| loss | wer |cer|\n| ---| --- | --- |--- |\n| **{step}**| {eval_metrics['loss']:.3f} | {eval_metrics['wer']:.3f} |{eval_metrics['cer']:.3f} |"
724
-
725
  # Put all predictions into a table
726
- inference_df = pd.DataFrame(columns=['mp3','target','prediction'])
727
-
728
  idx = 0
729
- for pred,label in zip(pred_ids,label_ids):
730
- pred_text = tokenizer.decode(pred,skip_special_tokens=True)
731
- label_text = tokenizer.decode(label,skip_special_tokens=True)
732
- formatted_pred_text = " ".join([f"**{word}**" if word in label_text else f"*{word}*" for word in pred_text.split()])
 
733
  #wer = 100 * metric_wer.compute(predictions=[pred_text], references=[label_text])
734
  #wer_formatted = f'{wer:.2f}'
735
  audio_control = f'[↓]({"mp3/pred_"+str(idx)+".mp3"})'
736
- new_row = pd.DataFrame({'mp3': audio_control, 'target': label_text, 'prediction': formatted_pred_text}, index=[0])
737
- inference_df = pd.concat([inference_df, new_row], ignore_index=True)
 
 
738
  idx += 1
739
-
740
  # Create the prediction table of the first N rows
741
- inference_df = inference_df[['mp3', 'target','prediction']]
742
- predict_table = inference_df[0:data_args.number_write_predictions].to_markdown(index=False)
743
-
 
744
  # Build the markdown page
745
  markdown_str = f"{eval_table}\n\n{predict_table}"
746
-
747
  # Save the stats file
748
  stats_file_name = f"{predictions_folder_name}/step_{step}.md"
749
  with open(stats_file_name, "w") as f:
750
  f.write(markdown_str)
751
-
752
  # Create an header for all the files
753
- md_files = sorted(os.path.basename(file) for file in os.listdir(predictions_folder_name) if file.startswith("step_"))
754
- sorted_md_files = sorted(md_files, key=lambda x: int(x[0:-3].split("_")[1]))
755
- md_header = " | ".join(f"[Step {file[:-3].split('_')[1]}]({file})" for file in sorted_md_files)
 
 
 
756
 
757
  # Add this header to all the stats files in the folder
758
  for filename in os.listdir(predictions_folder_name):
759
  if filename.startswith("step_"):
760
  with open(os.path.join(predictions_folder_name, filename), "r+") as f:
761
  content = f.read()
762
- new_content = md_header + "\n\n" + content[content.index("| STEP| loss | wer"):]
 
763
  f.seek(0)
764
  f.write(new_content)
765
  f.truncate()
766
-
767
  # Add a folder for the mp3 files
768
- if not os.path.exists(os.path.join(predictions_folder_name,"mp3")):
769
- os.makedirs(os.path.join(predictions_folder_name,"mp3"))
770
-
771
  for idx, mp3array in enumerate(eval_samples[0:data_args.number_write_predictions]):
772
- audio_segment = AudioSegment(data=mp3array.tobytes(), sample_width=mp3array.dtype.itemsize, frame_rate=16000, channels=1)
773
- file_path = os.path.join(predictions_folder_name, "mp3") + "/pred_" + str(idx) + ".mp3"
 
 
774
  if not os.path.exists(file_path):
775
  audio_segment.export(file_path, format="mp3")
776
-
777
- logger.info(f"Created {stats_file_name} and updated the headers of the other stats files")
778
-
 
779
  # 9. Save feature extractor, tokenizer and config
780
  feature_extractor.save_pretrained(training_args.output_dir)
781
  tokenizer.save_pretrained(training_args.output_dir)
@@ -800,7 +817,7 @@ def main():
800
  from flax.metrics.tensorboard import SummaryWriter
801
 
802
  summary_writer = SummaryWriter(
803
- log_dir=Path(os.path.join(training_args.output_dir,"events")))
804
  except ImportError as ie:
805
  has_tensorboard = False
806
  logger.warning(
@@ -963,11 +980,12 @@ def main():
963
 
964
  # Clean up the prediction folder if write_predictions is set to True
965
  if data_args.number_write_predictions:
966
- predictions_folder_name = os.path.join(training_args.output_dir, "predictions")
 
967
  shutil.rmtree(predictions_folder_name, ignore_errors=True)
968
  os.makedirs(predictions_folder_name, exist_ok=True)
969
  logger.info(f"Created folder {predictions_folder_name}")
970
-
971
  # Create parallel version of the train and eval step
972
  p_train_step = jax.pmap(
973
  partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0, )
@@ -995,7 +1013,8 @@ def main():
995
 
996
  train_metrics = []
997
  epoch = 0
998
- train_dataset = vectorized_datasets["train"].shuffle(seed=training_args.seed, buffer_size=data_args.shuffle_buffer_size)
 
999
  eval_dataset = vectorized_datasets["eval"]
1000
  train_loader = data_loader(train_dataset, train_batch_size)
1001
  # train
@@ -1025,9 +1044,11 @@ def main():
1025
  eval_preds = []
1026
  eval_labels = []
1027
  eval_samples = []
1028
- eval_loader = data_loader(eval_dataset, eval_batch_size, drop_last=False)
 
1029
  if data_args.max_eval_samples:
1030
- max_eval_steps_iter = range(1 + data_args.max_eval_samples // eval_batch_size)
 
1031
  else:
1032
  max_eval_steps_iter = itertools.repeat(None)
1033
  for _ in tqdm(max_eval_steps_iter, desc="Evaluating...", position=2, leave=False):
@@ -1043,9 +1064,9 @@ def main():
1043
  state.params, batch.data, min_device_batch=training_args.per_device_eval_batch_size
1044
  )
1045
  eval_metrics.append(metrics)
1046
- if training_args.predict_with_generate and data_args.number_write_predictions and len(eval_samples)<data_args.number_write_predictions+eval_batch_size:
1047
  eval_samples.extend(samples['input_features'])
1048
-
1049
  # generation
1050
  if training_args.predict_with_generate:
1051
  generated_ids = pad_shard_unpad(
@@ -1074,9 +1095,10 @@ def main():
1074
  if has_tensorboard and jax.process_index() == 0:
1075
  write_metric(summary_writer, train_metrics,
1076
  eval_metrics, train_time, step)
1077
-
1078
  if training_args.predict_with_generate and data_args.number_write_predictions:
1079
- write_predictions(step,eval_samples,eval_metrics, eval_preds, eval_labels)
 
1080
 
1081
  # save checkpoint after each epoch and push checkpoint to the hub
1082
  if jax.process_index() == 0:
 
66
  from transformers.utils.versions import require_version
67
 
68
  import pandas as pd
69
+
70
  # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
71
  check_min_version("4.27.0.dev0")
72
 
 
417
  def data_loader(
418
  dataset: Dataset,
419
  batch_size: int,
420
+ drop_last: bool = True,
421
+ num_workers: int = 0,
422
  ) -> Generator:
423
  """
424
  Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
 
455
 
456
 
457
  def create_learning_rate_fn(
458
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float, warmup_init_value: float = 0.0, decay_end_value: float = 0.0,
459
  ) -> Callable[[int], jnp.array]:
460
  """Returns a linear warmup, linear_decay learning rate function."""
461
  warmup_fn = optax.linear_schedule(
 
525
  if training_args.push_to_hub:
526
  if training_args.hub_model_id is None:
527
  repo_name = get_full_repo_name(
528
+ Path(training_args.output_dir).absolute().name,
529
+ token=training_args.hub_token,
530
+ organization=training_args.push_to_hub_organization,
531
  )
532
  else:
533
  repo_name = training_args.hub_model_id
534
+ create_repo(
535
+ repo_name, exist_ok=True, token=training_args.hub_token, private=training_args.hub_private_repo
536
+ )
537
  repo = Repository(training_args.output_dir,
538
  clone_from=repo_name, token=training_args.hub_token)
539
 
 
713
  label_str = [label_str[i]
714
  for i in range(len(label_str)) if len(label_str[i]) > 0]
715
 
716
+ wer = 100 * \
717
+ metric_wer.compute(predictions=pred_str, references=label_str)
718
+ cer = 100 * \
719
+ metric_cer.compute(predictions=pred_str, references=label_str)
720
 
721
  return {"wer": wer, "cer": cer}
722
 
723
  # TODO: Sanchit. This part is not essential and a bit verbose.
724
  # We do however think it is extremely useful info for debugging and improving the model
725
+
726
+ def write_predictions(step, eval_samples, eval_metrics, pred_ids, label_ids):
727
+ predictions_folder_name = os.path.join(
728
+ training_args.output_dir, "predictions")
729
  eval_table = f"| STEP| loss | wer |cer|\n| ---| --- | --- |--- |\n| **{step}**| {eval_metrics['loss']:.3f} | {eval_metrics['wer']:.3f} |{eval_metrics['cer']:.3f} |"
730
+
731
  # Put all predictions into a table
732
+ inference_df = pd.DataFrame(columns=['mp3', 'target', 'prediction'])
733
+
734
  idx = 0
735
+ for pred, label in zip(pred_ids, label_ids):
736
+ pred_text = tokenizer.decode(pred, skip_special_tokens=True)
737
+ label_text = tokenizer.decode(label, skip_special_tokens=True)
738
+ formatted_pred_text = " ".join(
739
+ [f"**{word}**" if word in label_text else f"*{word}*" for word in pred_text.split()])
740
  #wer = 100 * metric_wer.compute(predictions=[pred_text], references=[label_text])
741
  #wer_formatted = f'{wer:.2f}'
742
  audio_control = f'[↓]({"mp3/pred_"+str(idx)+".mp3"})'
743
+ new_row = pd.DataFrame(
744
+ {'mp3': audio_control, 'target': label_text, 'prediction': formatted_pred_text}, index=[0])
745
+ inference_df = pd.concat(
746
+ [inference_df, new_row], ignore_index=True)
747
  idx += 1
748
+
749
  # Create the prediction table of the first N rows
750
+ inference_df = inference_df[['mp3', 'target', 'prediction']]
751
+ predict_table = inference_df[0:data_args.number_write_predictions].to_markdown(
752
+ index=False)
753
+
754
  # Build the markdown page
755
  markdown_str = f"{eval_table}\n\n{predict_table}"
756
+
757
  # Save the stats file
758
  stats_file_name = f"{predictions_folder_name}/step_{step}.md"
759
  with open(stats_file_name, "w") as f:
760
  f.write(markdown_str)
761
+
762
+ # Create a header for all the files
763
+ md_files = sorted(os.path.basename(file) for file in os.listdir(
764
+ predictions_folder_name) if file.startswith("step_"))
765
+ sorted_md_files = sorted(
766
+ md_files, key=lambda x: int(x[0:-3].split("_")[1]))
767
+ md_header = " | ".join(
768
+ f"[Step {file[:-3].split('_')[1]}]({file})" for file in sorted_md_files)
769
 
770
  # Add this header to all the stats files in the folder
771
  for filename in os.listdir(predictions_folder_name):
772
  if filename.startswith("step_"):
773
  with open(os.path.join(predictions_folder_name, filename), "r+") as f:
774
  content = f.read()
775
+ new_content = md_header + "\n\n" + \
776
+ content[content.index("| STEP| loss | wer"):]
777
  f.seek(0)
778
  f.write(new_content)
779
  f.truncate()
780
+
781
  # Add a folder for the mp3 files
782
+ if not os.path.exists(os.path.join(predictions_folder_name, "mp3")):
783
+ os.makedirs(os.path.join(predictions_folder_name, "mp3"))
784
+
785
  for idx, mp3array in enumerate(eval_samples[0:data_args.number_write_predictions]):
786
+ audio_segment = AudioSegment(data=mp3array.tobytes(
787
+ ), sample_width=mp3array.dtype.itemsize, frame_rate=16000, channels=1)
788
+ file_path = os.path.join(
789
+ predictions_folder_name, "mp3") + "/pred_" + str(idx) + ".mp3"
790
  if not os.path.exists(file_path):
791
  audio_segment.export(file_path, format="mp3")
792
+
793
+ logger.info(
794
+ f"Created {stats_file_name} and updated the headers of the other stats files")
795
+
796
  # 9. Save feature extractor, tokenizer and config
797
  feature_extractor.save_pretrained(training_args.output_dir)
798
  tokenizer.save_pretrained(training_args.output_dir)
 
817
  from flax.metrics.tensorboard import SummaryWriter
818
 
819
  summary_writer = SummaryWriter(
820
+ log_dir=Path(os.path.join(training_args.output_dir, "events")))
821
  except ImportError as ie:
822
  has_tensorboard = False
823
  logger.warning(
 
980
 
981
  # Clean up the prediction folder if write_predictions is set to True
982
  if data_args.number_write_predictions:
983
+ predictions_folder_name = os.path.join(
984
+ training_args.output_dir, "predictions")
985
  shutil.rmtree(predictions_folder_name, ignore_errors=True)
986
  os.makedirs(predictions_folder_name, exist_ok=True)
987
  logger.info(f"Created folder {predictions_folder_name}")
988
+
989
  # Create parallel version of the train and eval step
990
  p_train_step = jax.pmap(
991
  partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0, )
 
1013
 
1014
  train_metrics = []
1015
  epoch = 0
1016
+ train_dataset = vectorized_datasets["train"].shuffle(
1017
+ seed=training_args.seed, buffer_size=data_args.shuffle_buffer_size)
1018
  eval_dataset = vectorized_datasets["eval"]
1019
  train_loader = data_loader(train_dataset, train_batch_size)
1020
  # train
 
1044
  eval_preds = []
1045
  eval_labels = []
1046
  eval_samples = []
1047
+ eval_loader = data_loader(
1048
+ eval_dataset, eval_batch_size, drop_last=False)
1049
  if data_args.max_eval_samples:
1050
+ max_eval_steps_iter = range(
1051
+ 1 + data_args.max_eval_samples // eval_batch_size)
1052
  else:
1053
  max_eval_steps_iter = itertools.repeat(None)
1054
  for _ in tqdm(max_eval_steps_iter, desc="Evaluating...", position=2, leave=False):
 
1064
  state.params, batch.data, min_device_batch=training_args.per_device_eval_batch_size
1065
  )
1066
  eval_metrics.append(metrics)
1067
+ if training_args.predict_with_generate and data_args.number_write_predictions and len(eval_samples) < data_args.number_write_predictions+eval_batch_size:
1068
  eval_samples.extend(samples['input_features'])
1069
+
1070
  # generation
1071
  if training_args.predict_with_generate:
1072
  generated_ids = pad_shard_unpad(
 
1095
  if has_tensorboard and jax.process_index() == 0:
1096
  write_metric(summary_writer, train_metrics,
1097
  eval_metrics, train_time, step)
1098
+
1099
  if training_args.predict_with_generate and data_args.number_write_predictions:
1100
+ write_predictions(step, eval_samples,
1101
+ eval_metrics, eval_preds, eval_labels)
1102
 
1103
  # save checkpoint after each epoch and push checkpoint to the hub
1104
  if jax.process_index() == 0: