from __future__ import print_function

import argparse
import copy
import logging
import os
import sys

import torch
import yaml
from torch.utils.data import DataLoader
from textgrid import TextGrid, IntervalTier

from wenet.dataset.dataset import Dataset
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
from wenet.utils.ctc_util import forced_align
from wenet.utils.common import get_subsample
from wenet.utils.init_model import init_model
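
# Illustrative usage (file and path names here are placeholders, not taken
# from the original script; the flags are the ones defined by the argument
# parser below):
#   python alignment.py --config train.yaml --input_file data.list \
#       --checkpoint final.pt --dict units.txt --result_file ali.txt --gen_praat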


def generator_textgrid(maxtime, lines, output):
    # Write a Praat TextGrid with one interval per "begin end token" line.
    interval = maxtime / (len(lines) + 1)
    margin = 0.0001

    tg = TextGrid(maxTime=maxtime)
    linetier = IntervalTier(name="line", maxTime=maxtime)

    for l in lines:
        s, e, w = l.split()
        linetier.add(minTime=float(s) + margin, maxTime=float(e), mark=w)

    tg.append(linetier)
    print("successfully generated {}".format(output))
    tg.write(output)


def get_frames_timestamp(alignment):
    # Split the frame-level CTC alignment into per-token segments: each
    # segment holds the blank frames (id 0) preceding a token plus the
    # (possibly repeated) frames of the token itself; trailing blanks are
    # folded into the last segment.
    timestamp = []
    start = 0
    end = 0
    while end < len(alignment):
        # skip leading blank frames
        while end < len(alignment) and alignment[end] == 0:
            end += 1
        if end == len(alignment):
            # only blanks remain: append them to the last segment
            timestamp[-1] += alignment[start:]
            break
        end += 1
        # absorb consecutive frames of the same token
        while end < len(alignment) and alignment[end - 1] == alignment[end]:
            end += 1
        timestamp.append(alignment[start:end])
        start = end
    return timestamp
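

# Illustrative example (not part of the original script): for a frame-level
# alignment such as [0, 0, 5, 5, 0, 7, 0], get_frames_timestamp() above
# returns [[0, 0, 5, 5], [0, 7, 0]] -- one segment per decoded token, with
# trailing blanks folded into the last segment.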


def get_labformat(timestamp, subsample):
    begin = 0
    duration = 0
    labformat = []
    for idx, t in enumerate(timestamp):
        # duration of this segment in seconds: 10 ms per frame, scaled by the
        # encoder subsampling factor
        duration = len(t) * 0.01 * subsample
        if idx < len(timestamp) - 1:
            print("{:.2f} {:.2f} {}".format(begin, begin + duration,
                                            char_dict[t[-1]]))
            labformat.append("{:.2f} {:.2f} {}\n".format(
                begin, begin + duration, char_dict[t[-1]]))
        else:
            # the last segment may end with blanks, so take its first
            # non-blank id as the token
            for i in t:
                if i != 0:
                    token = i
                    break
            print("{:.2f} {:.2f} {}".format(begin, begin + duration,
                                            char_dict[token]))
            labformat.append("{:.2f} {:.2f} {}\n".format(
                begin, begin + duration, char_dict[token]))
        begin = begin + duration
    return labformat
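

# Each labformat line is "begin end token" with times in seconds, e.g. an
# illustrative line would be "0.00 0.48 hello". These lines are written to a
# per-utterance .lab file and passed to generator_textgrid() when --gen_praat
# is set.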


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="use ctc to generate alignment")
    parser.add_argument("--config", required=True, help="config file")
    parser.add_argument("--input_file", required=True, help="format data file")
    parser.add_argument(
        "--data_type",
        default="raw",
        choices=["raw", "shard"],
        help="train and cv data type",
    )
    parser.add_argument(
        "--gpu", type=int, default=-1, help="gpu id for this rank, -1 for cpu"
    )
    parser.add_argument("--checkpoint", required=True, help="checkpoint model")
    parser.add_argument("--dict", required=True, help="dict file")
    parser.add_argument(
        "--non_lang_syms", help="non-linguistic symbol file. One symbol per line."
    )
    parser.add_argument("--result_file", required=True, help="alignment result file")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    parser.add_argument(
        "--gen_praat", action="store_true", help="convert alignment to a praat format"
    )
    parser.add_argument(
        "--bpe_model", default=None, type=str, help="bpe model for english part"
    )

    args = parser.parse_args()
    print(args)
    logging.basicConfig(
        level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s"
    )
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    if args.batch_size > 1:
        logging.fatal("alignment mode must run with batch_size == 1")
        sys.exit(1)

    with open(args.config, "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    # Load dict: map token id -> token symbol
    char_dict = {}
    with open(args.dict, "r") as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            char_dict[int(arr[1])] = arr[0]
    eos = len(char_dict) - 1

    symbol_table = read_symbol_table(args.dict)

    # Init dataset and data loader: disable filtering, augmentation, shuffling
    # and sorting so every utterance is aligned exactly as given.
    ali_conf = copy.deepcopy(configs["dataset_conf"])

    ali_conf["filter_conf"]["max_length"] = 102400
    ali_conf["filter_conf"]["min_length"] = 0
    ali_conf["filter_conf"]["token_max_length"] = 102400
    ali_conf["filter_conf"]["token_min_length"] = 0
    ali_conf["filter_conf"]["max_output_input_ratio"] = 102400
    ali_conf["filter_conf"]["min_output_input_ratio"] = 0
    ali_conf["speed_perturb"] = False
    ali_conf["spec_aug"] = False
    ali_conf["shuffle"] = False
    ali_conf["sort"] = False
    ali_conf["fbank_conf"]["dither"] = 0.0
    ali_conf["batch_conf"]["batch_type"] = "static"
    ali_conf["batch_conf"]["batch_size"] = args.batch_size
    non_lang_syms = read_non_lang_symbols(args.non_lang_syms)

    ali_dataset = Dataset(
        args.data_type,
        args.input_file,
        symbol_table,
        ali_conf,
        args.bpe_model,
        non_lang_syms,
        partition=False,
    )

    ali_data_loader = DataLoader(ali_dataset, batch_size=None, num_workers=0)

    # Init ASR model from configs and load the trained checkpoint
    model = init_model(configs)

    load_checkpoint(model, args.checkpoint)
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = model.to(device)
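
    # Note on the loop below: forced_align() returns one token id per encoder
    # frame, with 0 marking the CTC blank. The frame-level sequence is written
    # to --result_file and, when --gen_praat is set, converted to .lab and
    # .TextGrid files next to it.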
    model.eval()
    with torch.no_grad(), open(args.result_file, "w", encoding="utf-8") as fout:
        for batch_idx, batch in enumerate(ali_data_loader):
            print("#" * 80)
            key, feat, target, feats_length, target_length = batch
            print(key)

            feat = feat.to(device)
            target = target.to(device)
            feats_length = feats_length.to(device)
            target_length = target_length.to(device)

            # 1. Forward the encoder
            encoder_out, encoder_mask = model._forward_encoder(
                feat, feats_length
            )  # (B, maxlen, encoder_dim)
            maxlen = encoder_out.size(1)
            ctc_probs = model.ctc.log_softmax(encoder_out)  # (1, maxlen, vocab_size)

            # 2. CTC forced alignment against the reference label sequence
            ctc_probs = ctc_probs.squeeze(0)
            target = target.squeeze(0)
            alignment = forced_align(ctc_probs, target)
            print(alignment)
            fout.write("{} {}\n".format(key[0], alignment))

            # 3. Optionally convert the frame-level alignment to Praat files
            if args.gen_praat:
                timestamp = get_frames_timestamp(alignment)
                print(timestamp)
                subsample = get_subsample(configs)
                labformat = get_labformat(timestamp, subsample)

                lab_path = os.path.join(
                    os.path.dirname(args.result_file), key[0] + ".lab"
                )
                with open(lab_path, "w", encoding="utf-8") as f:
                    f.writelines(labformat)

                textgrid_path = os.path.join(
                    os.path.dirname(args.result_file), key[0] + ".TextGrid"
                )
                generator_textgrid(
                    maxtime=(len(alignment) + 1) * 0.01 * subsample,
                    lines=labformat,
                    output=textgrid_path,
                )