Added updated training script.
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
|
@@ -479,6 +479,7 @@ def main():
|
|
| 479 |
FULL2HALF = dict((i + 0xFEE0, i) for i in range(0x21, 0x7F))
|
| 480 |
FULL2HALF[0x3000] = 0x20
|
| 481 |
|
|
|
|
| 482 |
def prepare_dataset(batch):
|
| 483 |
# process audio
|
| 484 |
sample = batch[audio_column_name]
|
|
@@ -491,18 +492,20 @@ def main():
|
|
| 491 |
input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
|
| 492 |
if do_remove_punctuation:
|
| 493 |
input_str = normalizer(input_str).strip()
|
|
|
|
|
|
|
|
|
|
| 494 |
|
| 495 |
-
input_str = input_str.translate(FULL2HALF)
|
| 496 |
-
input_str = tagger.parse(input_str)
|
| 497 |
|
|
|
|
|
|
|
| 498 |
batch["labels"] = tokenizer(input_str).input_ids
|
| 499 |
return batch
|
| 500 |
|
|
|
|
| 501 |
with training_args.main_process_first(desc="dataset map pre-processing"):
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
remove_columns=raw_datasets_features,
|
| 505 |
-
).with_format("torch")
|
| 506 |
|
| 507 |
if training_args.do_train and data_args.streaming:
|
| 508 |
# manually shuffle if streaming (done by the trainer for non-streaming)
|
|
|
|
| 479 |
# Translation table for str.translate(): maps full-width ASCII forms
# (U+FF01–U+FF5E) to their half-width equivalents by subtracting the fixed
# Unicode offset 0xFEE0, e.g. 'Ａ' (U+FF21) -> 'A' (U+0041).
# Dict comprehension instead of dict(generator) — same table, idiomatic form.
FULL2HALF = {i + 0xFEE0: i for i in range(0x21, 0x7F)}
# Ideographic (full-width) space U+3000 has no 0xFEE0 relationship to
# ASCII space, so it is added explicitly.
FULL2HALF[0x3000] = 0x20
|
| 481 |
|
| 482 |
+
|
| 483 |
def prepare_dataset(batch):
|
| 484 |
# process audio
|
| 485 |
sample = batch[audio_column_name]
|
|
|
|
| 492 |
input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
|
| 493 |
if do_remove_punctuation:
|
| 494 |
input_str = normalizer(input_str).strip()
|
| 495 |
+
batch['labels'] = input_str.translate(FULL2HALF)
|
| 496 |
+
|
| 497 |
+
return batch
|
| 498 |
|
|
|
|
|
|
|
| 499 |
|
| 500 |
+
def tokenize_labels(batch):
    """Segment the transcript stored in batch['labels'], then replace it
    with the tokenizer's input ids.

    NOTE(review): `tagger` and `tokenizer` come from the enclosing scope;
    `tagger.parse` presumably performs word segmentation (MeCab-style) —
    confirm against the surrounding script.
    """
    segmented = tagger.parse(batch['labels'])
    batch["labels"] = tokenizer(segmented).input_ids
    return batch
|
| 504 |
|
| 505 |
+
|
| 506 |
# Run the dataset pre-processing on the main process first so other
# processes can reuse the result.
with training_args.main_process_first(desc="dataset map pre-processing"):
    # First pass: audio/text preparation, dropping the raw columns.
    prep_datasets = raw_datasets.map(
        prepare_dataset,
        remove_columns=raw_datasets_features,
    )
    # Second pass: segment + tokenize labels, then hand back torch tensors.
    vectorized_datasets = prep_datasets.map(tokenize_labels).with_format("torch")
|
|
|
|
|
|
|
| 509 |
|
| 510 |
if training_args.do_train and data_args.streaming:
|
| 511 |
# manually shuffle if streaming (done by the trainer for non-streaming)
|