| import sys, os |
| import json |
| import string |
| from tqdm import tqdm |
|
|
| def process(text): |
|
|
| |
| text = text.lower() |
|
|
| |
| punctuation_to_remove = string.punctuation.replace("'", "") |
| translation_table = str.maketrans('', '', punctuation_to_remove) |
| text = text.translate(translation_table) |
|
|
| |
| while text[0] == ' ' or text[-1] == ' ': |
| if text[0] == ' ': |
| text = text[1:] |
| if text[-1] == ' ': |
| text = text[:-1] |
| |
| return text |
|
|
# LibriSpeech split whose transcripts and mined examples we align on.
split_name = "train.other.500"


# Biasing list: one rare word per line, normalized exactly like the
# reference text so membership tests compare like with like.
# PERF FIX: built as a set (was a list) — the main loop tests
# `word in rarewords` for every word of every utterance, which is O(n)
# per test on a list but O(1) on a set. Only membership is ever used,
# so semantics are unchanged.
with open("./blist/all_rare_words.txt") as fin:
    rarewords = {process(word.strip()) for word in fin}


# Pre-computed ASR hypotheses for this split, one utterance per line,
# aligned with the dataset samples by line index.
with open(f"./transcripts/{split_name}.txt") as fin:
    transcripts = [line.strip() for line in fin]
|
|
| from datasets import load_dataset |
|
|
# HF datasets cache kept one level above the working directory (shared
# across runs/splits so the ~500h split is downloaded only once).
cache_dir = "./../cache"
# Loads every configured split of LibriSpeech; trust_remote_code is needed
# because this dataset ships its own loading script.
dataset = load_dataset("openslr/librispeech_asr", cache_dir=cache_dir, trust_remote_code=True)
|
|
train_data = []


# PERF: hoist the set conversion out of the per-word loop — membership in a
# list is O(n) per test and dominates runtime over a 500h split. This is a
# cheap no-op if `rarewords` is already a set.
rareword_set = set(rarewords)

pbar = tqdm(dataset[split_name])
for idx, sample in enumerate(pbar):
    # Ground-truth text, normalized identically to the biasing list.
    text = process(sample["text"])
    # ASR hypothesis aligned by line index; an out-of-sync transcript file
    # fails fast here with IndexError rather than silently mis-pairing.
    transcript = transcripts[idx]

    # Biasing words: rare words present in the reference but absent from
    # the hypothesis. NOTE(review): `word not in transcript` is a substring
    # test, so "art" inside "party" also counts as present — confirm that
    # this is intended rather than a whole-word check.
    bwords = [
        word
        for word in text.split()
        if word in rareword_set and word not in transcript
    ]

    if bwords:
        train_data.append({
            "split": split_name,
            "idx": idx,
            "text": text,
            "transcript": transcript,
            "b_words": bwords,
        })
    pbar.set_description(f"Len of train data: {len(train_data)}")


# Persist the mined examples for this split as pretty-printed JSON.
with open(f"./train_data/{split_name}.json", "w") as fout:
    json.dump(train_data, fout, indent=4)