| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import itertools |
| | import logging |
| | import re |
| | import time |
| |
|
| | from g2p_en import G2p |
| |
|
# Module-level logger; main() uses it to warn when submitit is unavailable.
logger = logging.getLogger(__name__)

# Sentinel token for a sentence whose phonemization failed: process_sent()
# collapses its whole output to this token when a segment yields it, and
# remove_punc() explicitly keeps it even though its "_" matches the
# punctuation pattern.
# NOTE(review): nothing in this file appears to produce it (presumably a
# reserved-word mapping or an external g2p could) — confirm.
FAIL_SENT = "FAILED_SENTENCE"
| |
|
| |
|
def parse():
    """Define the command-line interface and return the parsed namespace.

    Flags are registered in a fixed order (this is also the ``--help``
    order): I/O paths, text-normalization switches, phoneme duplication
    factors, reserved-word options and parallelism settings.
    """
    ap = argparse.ArgumentParser()
    # Required input / output file locations.
    for flag in ("--data-path", "--out-path"):
        ap.add_argument(flag, type=str, required=True)
    # On/off switches for text normalization and word-start marking.
    for flag in ("--lower-case", "--do-filter", "--use-word-start"):
        ap.add_argument(flag, action="store_true")
    # How many times each vowel / consonant phoneme is emitted (1 = as-is).
    for flag in ("--dup-vowel", "--dup-consonant"):
        ap.add_argument(flag, default=1, type=int)
    ap.add_argument("--no-punc", action="store_true")
    ap.add_argument("--reserve-word", type=str, default="")
    ap.add_argument(
        "--reserve-first-column",
        action="store_true",
        help="first column is sentence id",
    )
    ap.add_argument("--parallel-process-num", default=1, type=int)
    ap.add_argument("--logdir", default="")
    return ap.parse_args()
| |
|
| |
|
def process_sent(sent, g2p, res_wrds, args):
    """Convert one sentence into a space-joined phoneme string.

    The sentence is split around reserved words (``pre_process_sent``),
    each segment is phonemized (``do_g2p``) and the segments are
    concatenated.  Post-processing then follows ``args``: punctuation
    removal (``--no-punc``), phoneme duplication (``--dup-vowel`` /
    ``--dup-consonant``) and word-start marking (``--use-word-start``).

    Args:
        sent: raw sentence text.
        g2p: callable mapping a text segment to a list of phoneme tokens.
        res_wrds: reserved-word mapping (word -> token); an empty list when
            no reserve-word file was given.
        args: namespace returned by ``parse()``.

    Returns:
        The phoneme tokens joined by single spaces, or exactly ``FAIL_SENT``
        if any segment contains the failure sentinel.
    """
    sents = pre_process_sent(sent, args.do_filter, args.lower_case, res_wrds)
    pho_seqs = [do_g2p(g2p, s, res_wrds, i == 0) for i, s in enumerate(sents)]
    # do_g2p prepends " " to every non-first segment, so a failed later
    # segment looks like [" ", FAIL_SENT]; test membership, not equality.
    if any(FAIL_SENT in seq for seq in pho_seqs):
        # Return the sentinel untouched: dup_pho / add_word_start would
        # otherwise rewrite it (e.g. "FAILED_SENTENCE-1P", "▁FAILED_SENTENCE").
        return FAIL_SENT
    pho_seq = list(itertools.chain.from_iterable(pho_seqs))
    if args.no_punc:
        pho_seq = remove_punc(pho_seq)
    if args.dup_vowel > 1 or args.dup_consonant > 1:
        pho_seq = dup_pho(pho_seq, args.dup_vowel, args.dup_consonant)
    if args.use_word_start:
        pho_seq = add_word_start(pho_seq)
    return " ".join(pho_seq)
| |
|
| |
|
def remove_punc(sent):
    """Drop punctuation tokens from a phoneme sequence.

    A token is kept when it contains only ASCII letters, digits and spaces,
    or when it is the ``FAIL_SENT`` sentinel (whose "_" would otherwise
    match the punctuation pattern).  Word-separator tokens (" ") are
    normalized: removing punctuation can leave stray separators behind, so
    the result never starts or ends with " " and never has two in a row.

    Args:
        sent: list of phoneme tokens.

    Returns:
        A new list with punctuation (and stray separators) removed.
    """
    regex = re.compile("[^a-zA-Z0-9 ]")
    ns = []
    for p in sent:
        if regex.search(p) and p != FAIL_SENT:
            continue
        if p == " " and (not ns or ns[-1] == " "):
            continue
        ns.append(p)
    # A sequence ending in punctuation (e.g. ["OW1", " ", "."]) leaves a
    # trailing separator; drop it like leading / duplicate ones.
    if ns and ns[-1] == " ":
        ns.pop()
    return ns
| |
|
| |
|
def do_g2p(g2p, sent, res_wrds, is_first_sent):
    """Phonemize a single sentence segment.

    A segment that is exactly a reserved word becomes its mapped token;
    any other segment goes through ``g2p``.  All segments except the first
    are prefixed with a " " token so that, once the segments are
    concatenated, neighbouring segments remain word-separated.
    """
    tokens = [res_wrds[sent]] if sent in res_wrds else g2p(sent)
    return tokens if is_first_sent else [" "] + tokens
| |
|
| |
|
def pre_process_sent(sent, do_filter, lower_case, res_wrds):
    """Split a sentence into segments that ``do_g2p`` handles one by one.

    With ``do_filter``, hyphens and em dashes are turned into spaces first.
    When reserved words are given, each reserved word becomes a segment of
    its own (the text between them forms the other segments, empty ones
    dropped); otherwise the whole sentence is one segment.  With
    ``lower_case``, every segment that is not a reserved word is lowercased.
    """
    if do_filter:
        # Treat hyphen and em dash as plain word separators.
        for dash in ("-", "—"):
            sent = sent.replace(dash, " ")

    if not res_wrds:
        segments = [sent]
    else:
        # Surround reserved words with a marker, then cut on the marker.
        marked = []
        for word in sent.split():
            marked.append(f"SPLIT_ME {word} SPLIT_ME" if word in res_wrds else word)
        segments = []
        for piece in " ".join(marked).split("SPLIT_ME"):
            piece = piece.strip()
            if piece != "":
                segments.append(piece)

    if lower_case:
        segments = [seg if seg in res_wrds else seg.lower() for seg in segments]
    return segments
| |
|
| |
|
def dup_pho(sent, dup_v_num, dup_c_num):
    """Repeat phonemes as marked copies, using CMUdict phoneme conventions.

    See http://www.speech.cs.cmu.edu/cgi-bin/cmudict

    Each token is followed by ``n - 1`` extra copies named
    ``"<token>-<k>P"`` (k = 1 .. n-1).  For tokens ending in a digit
    (CMUdict vowels carry a stress digit), n is ``dup_v_num``.  For other
    tokens containing a word character, n is ``dup_c_num``.  Everything
    else (separators, punctuation) is left alone.  If both factors are 1,
    the input list itself is returned.
    """
    if dup_v_num == 1 and dup_c_num == 1:
        return sent
    vowel_re = re.compile(r"\d$")
    word_re = re.compile(r"\w")
    out = []
    for pho in sent:
        out.append(pho)
        if vowel_re.search(pho):
            copies = dup_v_num
        elif word_re.search(pho):
            copies = dup_c_num
        else:
            continue
        out.extend(f"{pho}-{k}P" for k in range(1, copies))
    return out
| |
|
| |
|
def add_word_start(sent):
    """Mark word starts with "▁" and drop the " " separator tokens.

    The first phoneme of the sequence and the first phoneme after each " "
    separator get a "▁" prefix; separators themselves are removed.  The
    separator check happens before prefixing, so a leading or repeated " "
    can no longer turn into a bogus "▁ " token.

    Args:
        sent: list of phoneme tokens, with " " marking word boundaries.

    Returns:
        A new list without separators, word-initial phonemes prefixed.
    """
    ws = "▁"
    ns = []
    at_word_start = True
    for p in sent:
        if p == " ":
            at_word_start = True
            continue
        if at_word_start:
            p = ws + p
            at_word_start = False
        ns.append(p)
    return ns
| |
|
| |
|
def load_reserve_word(reserve_word):
    """Load the reserved-word mapping from a whitespace-separated file.

    Every non-blank line must hold exactly two fields: the word as it
    appears in the text and the token to emit for it.  If a word appears
    more than once, its last line wins.

    Args:
        reserve_word: path to the mapping file, or "" for no mapping.

    Returns:
        ``[]`` when ``reserve_word`` is "", otherwise a dict
        ``{word: token}``.

    Raises:
        ValueError: if a non-blank line does not have exactly two fields.
            (Previously an ``assert``, which ``python -O`` silently skips.)
    """
    if reserve_word == "":
        return []
    res_wrds = {}
    with open(reserve_word, "r") as fp:
        for lineno, line in enumerate(fp, start=1):
            fields = line.split()
            if not fields:
                continue
            if len(fields) != 2:
                raise ValueError(
                    f"{reserve_word}:{lineno}: expected 2 fields (word, token), "
                    f"got {len(fields)}: {line.strip()!r}"
                )
            word, token = fields
            res_wrds[word] = token
    return res_wrds
| |
|
| |
|
def process_sents(sents, args):
    """Phonemize a list of sentences, returning one output line per input.

    Each call builds its own ``G2p`` instance and reloads the reserved-word
    file.  main() calls it directly, or once per chunk inside separate
    submitit jobs.  With ``--reserve-first-column``, the first
    whitespace-separated field of each line is an id: it is held back from
    conversion and re-attached in front of the result.

    Args:
        sents: list of input lines (already stripped).
        args: namespace returned by ``parse()``.

    Returns:
        List of converted lines, in input order.
    """
    g2p = G2p()
    res_wrds = load_reserve_word(args.reserve_word)
    out_sents = []
    for sent in sents:
        col1 = ""
        if args.reserve_first_column:
            # Tolerate blank lines and id-only lines instead of crashing
            # on tuple unpacking.
            parts = sent.split(None, 1)
            col1 = parts[0] if parts else ""
            sent = parts[1] if len(parts) > 1 else ""
        sent = process_sent(sent, g2p, res_wrds, args)
        if args.reserve_first_column and col1 != "":
            sent = f"{col1} {sent}"
        out_sents.append(sent)
    return out_sents
| |
|
| |
|
def main():
    """Command-line entry point: phonemize ``--data-path`` into ``--out-path``.

    Input lines are stripped and converted with ``process_sents``.  With
    ``--parallel-process-num`` > 1 and submitit installed, the work is split
    into contiguous chunks, one submitit job per chunk.  Results are written
    one per line, in input order.
    """
    args = parse()
    with open(args.data_path, "r") as fp:
        sent_list = [x.strip() for x in fp]

    # Bound up front so values <= 1 (including 0 or negatives) cleanly take
    # the single-process path instead of hitting an unbound local.
    submitit = None
    if args.parallel_process_num > 1:
        try:
            import submitit
        except ImportError:
            logger.warning(
                "submitit is not found and only one job is used to process the data"
            )
            submitit = None

    if submitit is None:
        out_sents = process_sents(sent_list, args)
    else:
        # Ceil-divide so at most parallel_process_num chunks cover the list;
        # iterating over chunk starts never launches a job for an empty chunk.
        lsize = max(1, -(-len(sent_list) // args.parallel_process_num))
        executor = submitit.AutoExecutor(folder=args.logdir)
        executor.update_parameters(timeout_min=1000, cpus_per_task=4)
        jobs = [
            executor.submit(process_sents, sent_list[start : start + lsize], args)
            for start in range(0, len(sent_list), lsize)
        ]
        while not all(job.done() for job in jobs):
            time.sleep(5)
        out_sents = list(itertools.chain.from_iterable(job.result() for job in jobs))

    with open(args.out_path, "w") as fp:
        # One line per sentence; empty input yields an empty file rather
        # than a single stray newline.
        fp.writelines(s + "\n" for s in out_sents)


if __name__ == "__main__":
    main()
| |
|