File size: 4,555 Bytes
d596074 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
#!/usr/bin/env python3
# Copyright 2023 Johns Hopkins University (author: Dongji Gao)
import argparse
import random
from pathlib import Path
from typing import List
from lhotse import CutSet, load_manifest
from lhotse.cut.base import Cut
from icefall.utils import str2bool
def get_args():
    """Build the command-line interface and parse the process arguments.

    Returns:
      The ``argparse.Namespace`` produced by ``parse_args()``, carrying
      ``input_cutset``, ``words_file``, ``otc_token``, ``sub_error_rate``,
      ``ins_error_rate``, ``del_error_rate``, ``output_cutset`` and
      ``verbose``.
    """
    # One (flag, add_argument keyword arguments) entry per option, listed in
    # the order in which they should appear in the --help output.
    option_specs = (
        (
            "--input-cutset",
            {
                "type": str,
                "help": "Supervision manifest that contains verbatim transcript",
            },
        ),
        (
            "--words-file",
            {
                "type": str,
                "help": "words.txt file",
            },
        ),
        (
            "--otc-token",
            {
                "type": str,
                "help": "OTC token in words.txt",
            },
        ),
        (
            "--sub-error-rate",
            {
                "type": float,
                "default": 0.0,
                "help": "Substitution error rate",
            },
        ),
        (
            "--ins-error-rate",
            {
                "type": float,
                "default": 0.0,
                "help": "Insertion error rate",
            },
        ),
        (
            "--del-error-rate",
            {
                "type": float,
                "default": 0.0,
                "help": "Deletion error rate",
            },
        ),
        (
            "--output-cutset",
            {
                "type": str,
                "default": "",
                "help": "Supervision manifest that contains modified non-verbatim transcript",
            },
        ),
        (
            "--verbose",
            {
                "type": str2bool,
                "help": "show details of errors",
            },
        ),
    )

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()
def check_args(args):
    """Validate the error rates parsed from the command line.

    Every individual rate must lie in [0, 1] and the three rates together
    must not exceed 1, because deletion, substitution and insertion are
    applied as mutually exclusive outcomes of a single random draw per token.

    Args:
      args: Object exposing ``sub_error_rate``, ``ins_error_rate`` and
        ``del_error_rate`` as numbers.

    Raises:
      AssertionError: If any rate is outside [0, 1] or if the rates sum to
        more than 1.
    """
    rates = {
        "--sub-error-rate": args.sub_error_rate,
        "--ins-error-rate": args.ins_error_rate,
        "--del-error-rate": args.del_error_rate,
    }

    # Check each rate against its own value so that the failure message
    # names the option that is actually out of range.
    for flag, rate in rates.items():
        assert 0.0 <= rate <= 1.0, f"{flag} must be in [0, 1], got {rate}"

    total_error_rate = sum(rates.values())
    assert (
        total_error_rate <= 1.0
    ), f"The error rates must sum to at most 1, got {total_error_rate}"
def get_word_list(token_path: str) -> List:
    """Read the vocabulary from a ``words.txt``-style symbol table.

    The first whitespace-separated field of every non-blank line is taken as
    the word; any remaining fields on the line are ignored.

    Args:
      token_path: Path of the file to read.

    Returns:
      The words in file order.

    Raises:
      AssertionError: If the same word appears on more than one line.
    """
    word_list = []
    # Membership is tracked in a set so the duplicate check stays O(1) per
    # line instead of rescanning the growing list.
    seen = set()

    # Decode as UTF-8 explicitly so the result does not depend on the locale.
    with open(Path(token_path), "r", encoding="utf-8") as tp:
        for line in tp:
            fields = line.split()
            if not fields:
                # A blank (or whitespace-only) line carries no word.
                continue
            token = fields[0]
            assert token not in seen, f"Duplicate token {token!r} in {token_path}"
            seen.add(token)
            word_list.append(token)

    return word_list
def modify_cut_text(
    cut: Cut,
    words_list: List,
    non_words: List,
    sub_ratio: float = 0.0,
    ins_ratio: float = 0.0,
    del_ratio: float = 0.0,
):
    """Randomly corrupt the transcript of the first supervision of ``cut``.

    One random number is drawn per token of the transcript and decides, in
    the order deletion -> substitution -> insertion, what happens to it:

      * deletion: the token is dropped from the transcript;
      * substitution: the token is replaced by a random word;
      * insertion: the token is kept and a random word is added after it;
      * otherwise the token is kept unchanged.

    A random word is drawn from ``words_list`` until it differs from the
    current token, is not in ``non_words`` and does not start with ``#``.
    NOTE: this loops forever if ``words_list`` contains no such word.

    A marked copy of the original transcript is stored for debugging in
    ``verbatim_text`` (unless that attribute already exists): deleted tokens
    appear as ``-token-``, substituted tokens as ``[token]`` and each
    insertion point as ``[]``.

    Args:
      cut: Cut whose ``supervisions[0]`` is modified in place.
      words_list: Candidate words for substitutions and insertions.
      non_words: Symbols that must never be chosen as a random word.
      sub_ratio: Probability of substituting a token.
      ins_ratio: Probability of inserting a word after a token.
      del_ratio: Probability of deleting a token.

    Returns:
      The same ``cut`` object, with the text of its first supervision
      replaced by the corrupted transcript.
    """
    supervision = cut.supervisions[0]

    # We save the modified information of the original verbatim text for debugging
    marked_verbatim_text_list = []
    modified_text_list = []

    # Cumulative thresholds that split [0, 1) into the deletion,
    # substitution, insertion and "keep" intervals.
    sub_threshold = del_ratio + sub_ratio
    ins_threshold = sub_threshold + ins_ratio

    # We follow the order: deletion -> substitution -> insertion
    for token in supervision.text.split():
        # random.random() is in [0, 1), so strict comparisons give each
        # outcome exactly its configured probability (a rate of 0 never fires).
        prob = random.random()

        if prob < del_ratio:
            # Record the deletion in the marked text only; appending an empty
            # string to the modified text would leave stray spaces after join.
            marked_verbatim_text_list.append(f"-{token}-")
            continue

        if prob >= ins_threshold:
            # No error for this token.
            marked_verbatim_text_list.append(token)
            modified_text_list.append(token)
            continue

        if prob < sub_threshold:
            marked_token = f"[{token}]"
        else:
            # Insertion keeps the original token and adds a new word after it.
            marked_verbatim_text_list.append(token)
            modified_text_list.append(token)
            marked_token = "[]"

        # Starting from the token itself forces at least one draw.
        modified_token = token
        while (
            modified_token == token
            or modified_token in non_words
            or modified_token.startswith("#")
        ):
            modified_token = random.choice(words_list)

        marked_verbatim_text_list.append(marked_token)
        modified_text_list.append(modified_token)

    # Keep an already stored verbatim transcript untouched.
    if not hasattr(supervision, "verbatim_text"):
        supervision.verbatim_text = " ".join(marked_verbatim_text_list)
    supervision.text = " ".join(modified_text_list)

    return cut
def main():
    """Corrupt every cut of the input manifest and write the result.

    Parses and validates the command-line arguments, reads the vocabulary
    and the input cut set, applies ``modify_cut_text`` to each cut with the
    requested error rates, and saves the modified cuts to the output path.
    """
    args = get_args()
    check_args(args)

    # Symbols that must never be picked as a substituted or inserted word.
    non_words = {"sil", "<UNK>", "<eps>", args.otc_token}
    words_list = get_word_list(args.words_file)

    corrupted_cuts = [
        modify_cut_text(
            cut=cut,
            words_list=words_list,
            non_words=non_words,
            sub_ratio=args.sub_error_rate,
            ins_ratio=args.ins_error_rate,
            del_ratio=args.del_error_rate,
        )
        for cut in load_manifest(Path(args.input_cutset))
    ]

    CutSet.from_cuts(corrupted_cuts).to_file(args.output_cutset)
# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|