kimbochen commited on
Commit
cdee5d5
·
1 Parent(s): 2ef6c77

Added updated training script.

Browse files
run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -479,6 +479,7 @@ def main():
479
  FULL2HALF = dict((i + 0xFEE0, i) for i in range(0x21, 0x7F))
480
  FULL2HALF[0x3000] = 0x20
481
 
 
482
  def prepare_dataset(batch):
483
  # process audio
484
  sample = batch[audio_column_name]
@@ -491,18 +492,20 @@ def main():
491
  input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
492
  if do_remove_punctuation:
493
  input_str = normalizer(input_str).strip()
 
 
 
494
 
495
- input_str = input_str.translate(FULL2HALF)
496
- input_str = tagger.parse(input_str)
497
 
 
 
498
  batch["labels"] = tokenizer(input_str).input_ids
499
  return batch
500
 
 
501
  with training_args.main_process_first(desc="dataset map pre-processing"):
502
- vectorized_datasets = raw_datasets.map(
503
- prepare_dataset,
504
- remove_columns=raw_datasets_features,
505
- ).with_format("torch")
506
 
507
  if training_args.do_train and data_args.streaming:
508
  # manually shuffle if streaming (done by the trainer for non-streaming)
 
479
  FULL2HALF = dict((i + 0xFEE0, i) for i in range(0x21, 0x7F))
480
  FULL2HALF[0x3000] = 0x20
481
 
482
+
483
  def prepare_dataset(batch):
484
  # process audio
485
  sample = batch[audio_column_name]
 
492
  input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
493
  if do_remove_punctuation:
494
  input_str = normalizer(input_str).strip()
495
+ batch['labels'] = input_str.translate(FULL2HALF)
496
+
497
+ return batch
498
 
 
 
499
 
500
+ def tokenize_labels(batch):
501
+ input_str = tagger.parse(batch['labels'])
502
  batch["labels"] = tokenizer(input_str).input_ids
503
  return batch
504
 
505
+
506
  with training_args.main_process_first(desc="dataset map pre-processing"):
507
+ prep_datasets = raw_datasets.map(prepare_dataset, remove_columns=raw_datasets_features)
508
+ vectorized_datasets = prep_datasets.map(tokenize_labels).with_format("torch")
 
 
509
 
510
  if training_args.do_train and data_args.streaming:
511
  # manually shuffle if streaming (done by the trainer for non-streaming)