File size: 4,555 Bytes
d596074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
#!/usr/bin/env python3

# Copyright 2023 Johns Hopkins University (author: Dongji Gao)

import argparse
import random
from pathlib import Path
from typing import List

from lhotse import CutSet, load_manifest
from lhotse.cut.base import Cut

from icefall.utils import str2bool


def get_args():
    """Parse and return the command-line options for this script."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--input-cutset",
        type=str,
        help="Supervision manifest that contains verbatim transcript",
    )
    parser.add_argument(
        "--words-file",
        type=str,
        help="words.txt file",
    )
    parser.add_argument(
        "--otc-token",
        type=str,
        help="OTC token in words.txt",
    )

    # The three corruption probabilities have identical shape (float,
    # default 0.0), so register them in one pass.
    for flag, description in (
        ("--sub-error-rate", "Substitution error rate"),
        ("--ins-error-rate", "Insertion error rate"),
        ("--del-error-rate", "Deletion error rate"),
    ):
        parser.add_argument(flag, type=float, default=0.0, help=description)

    parser.add_argument(
        "--output-cutset",
        type=str,
        default="",
        help="Supervision manifest that contains modified non-verbatim transcript",
    )
    parser.add_argument("--verbose", type=str2bool, help="show details of errors")

    return parser.parse_args()


def check_args(args):
    """Validate the error-rate options parsed by ``get_args``.

    Each rate must lie in [0, 1], and the three rates together must not
    exceed 1.0 (in ``modify_cut_text`` they partition a single uniform
    draw, so their sum is a probability).

    Args:
      args: parsed argument namespace with ``sub_error_rate``,
        ``ins_error_rate`` and ``del_error_rate`` attributes.

    Raises:
      ValueError: if any rate is outside [0, 1] or the rates sum to > 1.0.
    """
    # NOTE: the original code compared ``sub_error_rate`` in all three
    # range checks (copy-paste bug), so bad ins/del rates slipped through.
    # Also, ``assert`` is stripped under ``python -O``; raise instead.
    rates = {
        "sub-error-rate": args.sub_error_rate,
        "ins-error-rate": args.ins_error_rate,
        "del-error-rate": args.del_error_rate,
    }
    for name, rate in rates.items():
        if not 0.0 <= rate <= 1.0:
            raise ValueError(f"--{name} must be in [0, 1], got {rate}")

    total_error_rate = sum(rates.values())
    if total_error_rate > 1.0:
        raise ValueError(
            f"Error rates must sum to at most 1.0, got {total_error_rate}"
        )


def get_word_list(token_path: str) -> List[str]:
    """Read the first column of a ``words.txt``-style file.

    Each non-blank line is expected to start with a word token; any
    following columns (e.g. integer ids) are ignored.  Blank lines are
    skipped instead of crashing on ``line.split()[0]``.

    Args:
      token_path: path to the words.txt file.

    Returns:
      The word tokens in file order.

    Raises:
      ValueError: if the same token appears twice.
    """
    word_list: List[str] = []
    # A parallel set gives O(1) duplicate detection; the original
    # ``token not in word_list`` scan made this loop O(n^2).
    seen = set()
    with open(Path(token_path), "r") as tp:
        for line in tp:
            fields = line.split()
            if not fields:
                continue
            token = fields[0]
            if token in seen:
                raise ValueError(f"Duplicate token in {token_path}: {token}")
            seen.add(token)
            word_list.append(token)
    return word_list


def modify_cut_text(
    cut: Cut,
    words_list: List,
    non_words: List,
    sub_ratio: float = 0.0,
    ins_ratio: float = 0.0,
    del_ratio: float = 0.0,
):
    """Randomly corrupt the first supervision's transcript of ``cut``.

    For each token one uniform draw decides its fate: deletion if
    ``prob <= del_ratio``, substitution if
    ``del_ratio < prob <= del_ratio + sub_ratio``, insertion (the original
    token is kept and one random token is inserted after it) if
    ``del_ratio + sub_ratio < prob <= del_ratio + sub_ratio + ins_ratio``,
    otherwise the token is left unchanged.

    The original text annotated with error markers (``-tok-`` deleted,
    ``[tok]`` substituted, ``[]`` insertion site) is saved to
    ``cut.supervisions[0].verbatim_text`` (only if that attribute does not
    already exist); the corrupted text overwrites
    ``cut.supervisions[0].text``.

    Args:
      cut: lhotse Cut modified in place (and also returned).
      words_list: vocabulary to draw replacement/inserted tokens from.
      non_words: tokens never chosen as replacements; NOTE(review): the
        caller in this file passes a ``set``, any ``in``-supporting
        container works despite the ``List`` annotation.
      sub_ratio: per-token substitution probability.
      ins_ratio: per-token insertion probability.
      del_ratio: per-token deletion probability.

    Returns:
      The same ``cut`` object, modified in place.
    """
    text = cut.supervisions[0].text
    text_list = text.split()

    # We save the modified information of the original verbatim text for debugging
    marked_verbatim_text_list = []
    modified_text_list = []

    # NOTE(review): these three sets are never written to or read below;
    # they look like leftovers and are candidates for removal.
    del_index_set = set()
    sub_index_set = set()
    ins_index_set = set()

    # We follow the order: deletion -> substitution -> insertion
    for token in text_list:
        marked_token = token
        modified_token = token

        # Single uniform draw in [0, 1); the three error types partition
        # the interval [0, del_ratio + sub_ratio + ins_ratio].
        prob = random.random()

        if prob <= del_ratio:
            # Deletion: marked as -token- in the verbatim text, dropped
            # from the modified text.  NOTE(review): the empty string is
            # still joined below, so a deletion leaves a double space in
            # the modified text.
            marked_token = f"-{token}-"
            modified_token = ""
        elif prob <= del_ratio + sub_ratio + ins_ratio:
            if prob <= del_ratio + sub_ratio:
                # Substitution: replace this token with a random one.
                marked_token = f"[{token}]"
            else:
                # Insertion: emit the original token first, then fall
                # through to generate the extra token; "[]" marks the
                # insertion site in the verbatim text.
                marked_verbatim_text_list.append(marked_token)
                modified_text_list.append(modified_token)
                marked_token = "[]"

            # get new_token: resample until it differs from the original,
            # is not a non-word (e.g. the OTC token), and is not a
            # disambiguation-style symbol starting with "#".
            while (
                modified_token == token
                or modified_token in non_words
                or modified_token.startswith("#")
            ):
                modified_token = random.choice(words_list)

        marked_verbatim_text_list.append(marked_token)
        modified_text_list.append(modified_token)

    marked_text = " ".join(marked_verbatim_text_list)
    modified_text = " ".join(modified_text_list)

    # Keep the first annotated verbatim transcript if the cut was already
    # modified once before.
    if not hasattr(cut.supervisions[0], "verbatim_text"):
        cut.supervisions[0].verbatim_text = marked_text
    cut.supervisions[0].text = modified_text

    return cut


def main():
    """Entry point: corrupt every transcript in the cutset and save it."""
    args = get_args()
    check_args(args)

    # Tokens that must never be drawn as substitution/insertion words.
    non_words = {"sil", "<UNK>", "<eps>", args.otc_token}

    words_list = get_word_list(args.words_file)
    cutset = load_manifest(Path(args.input_cutset))

    modified_cuts = [
        modify_cut_text(
            cut=cut,
            words_list=words_list,
            non_words=non_words,
            sub_ratio=args.sub_error_rate,
            ins_ratio=args.ins_error_rate,
            del_ratio=args.del_error_rate,
        )
        for cut in cutset
    ]

    CutSet.from_cuts(modified_cuts).to_file(args.output_cutset)


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()