| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import print_function |
|
|
| import argparse |
| import copy |
| import logging |
| import os |
| import sys |
|
|
| import torch |
| import yaml |
| from torch.utils.data import DataLoader |
| from textgrid import TextGrid, IntervalTier |
| import math |
|
|
| from wenet.dataset.dataset import Dataset |
| from wenet.utils.ctc_utils import force_align |
| from wenet.utils.common import get_subsample |
| from wenet.utils.init_model import init_model |
| from wenet.utils.init_tokenizer import init_tokenizer |
|
|
|
|
def generator_textgrid(maxtime, lines, output):
    """Write label lines to a TextGrid file containing a single "line" tier.

    Args:
        maxtime: Value passed as ``maxTime`` to both the TextGrid and its
            interval tier.
        lines: Iterable of strings; each must split on whitespace into
            exactly three fields: start time, end time and mark.
        output: Path the TextGrid is written to.

    Raises:
        ValueError: If a line does not split into exactly three fields, or
            its start/end fields cannot be parsed as floats.
    """
    # Every interval starts slightly after the parsed start time.
    # NOTE(review): presumably this keeps an interval from beginning exactly
    # where the previous one ends -- confirm against the textgrid package.
    margin = 0.0001

    tg = TextGrid(maxTime=maxtime)
    linetier = IntervalTier(name="line", maxTime=maxtime)

    for line in lines:
        start, end, mark = line.split()
        linetier.add(minTime=float(start) + margin,
                     maxTime=float(end),
                     mark=mark)

    tg.append(linetier)
    tg.write(output)
    # Report success only after the file has actually been written.
    print("successfully generator {}".format(output))
|
|
|
|
def get_frames_timestamp(alignment,
                         prob,
                         blank_thres=0.999,
                         thres=0.0000000001):
    """Split a frame-level alignment into one segment per emitted token.

    Id ``0`` is treated as blank. Each segment is a slice of ``alignment``
    made of the blank frames preceding a token followed by the run of frames
    carrying that token. Trailing blank frames after the last token are
    appended to the last segment.

    Before a segment is cut, the token is extended backwards over preceding
    frames for as long as the frame's blank log-probability is below
    ``log(blank_thres)`` or its log-probability for the token is above
    ``log(thres)``.

    Args:
        alignment: Mutable sequence of token ids, one per frame. NOTE: it is
            modified in place by the backward extension described above.
        prob: Indexable as ``prob[frame][token_id]``; compared against
            ``math.log`` of the thresholds, so values are expected to be
            log-probabilities.
        blank_thres: Probability threshold for the blank id.
        thres: Probability threshold for the token id.

    Returns:
        A list of segments (slices of ``alignment``). Empty when the
        alignment contains no non-blank frame.

    Raises:
        ValueError: If ``blank_thres`` or ``thres`` is not positive
            (raised by ``math.log``).
    """
    # Loop-invariant thresholds in the log domain.
    log_blank_thres = math.log(blank_thres)
    log_thres = math.log(thres)

    num_frames = len(alignment)
    timestamp = []
    start = 0
    end = 0
    while end < num_frames:
        # Skip the blank frames that precede the next token.
        while end < num_frames and alignment[end] == 0:
            end += 1
        if end == num_frames:
            # Only blanks remain: attach them to the previous segment. If no
            # token was ever seen there is no segment to attach them to, so
            # the result stays empty instead of raising IndexError.
            if timestamp:
                timestamp[-1] += alignment[start:]
            break
        # Consume the run of identical token frames.
        end += 1
        while end < num_frames and alignment[end - 1] == alignment[end]:
            end += 1
        token = alignment[end - 1]

        # Walk backwards from the last frame of the run, relabelling frames
        # with the token while either threshold condition holds.
        local_start = end - 1
        while local_start >= start and (
                prob[local_start][0] < log_blank_thres
                or prob[local_start][token] > log_thres):
            alignment[local_start] = token
            local_start -= 1

        timestamp.append(alignment[start:end])
        start = end
    return timestamp
|
|
|
|
def get_labformat(timestamp, subsample, token_table=None):
    """Convert per-token frame segments into "start end token" label lines.

    Each segment contributes one line. The start time is the offset of the
    segment's first non-blank frame and the duration is the length of the
    non-blank run beginning there; frame counts are converted to seconds as
    ``frames * 0.01 * subsample``. Every line is also printed to stdout.

    Args:
        timestamp: List of segments, each a sequence of token ids where
            ``0`` is blank.
        subsample: Multiplier applied to the 0.01 per-frame step.
        token_table: Optional mapping from token id to symbol. When omitted,
            the module-level ``char_dict`` is used.

    Returns:
        A list of strings formatted as ``"{:.2f} {:.2f} {}\\n"``. Segments
        without any non-blank frame produce no line but still advance the
        running time offset.
    """
    table = char_dict if token_table is None else token_table

    labformat = []
    begin_time = 0
    last_idx = len(timestamp) - 1
    for idx, t in enumerate(timestamp):
        # Index of the first non-blank frame (bounded, so an all-blank
        # segment cannot run off the end).
        first = 0
        while first < len(t) and t[first] == 0:
            first += 1

        if first < len(t):
            # Length of the non-blank run starting at `first`.
            stop = first
            while stop < len(t) and t[stop] != 0:
                stop += 1
            dur = stop - first

            begin = begin_time + first * 0.01 * subsample
            duration = dur * 0.01 * subsample
            # The last segment may end with blank frames, so its token is
            # taken from the first non-blank frame rather than from t[-1].
            token = t[-1] if idx < last_idx else t[first]

            entry = "{:.2f} {:.2f} {}".format(begin, begin + duration,
                                              table[token])
            print(entry)
            labformat.append(entry + "\n")

        begin_time += len(t) * 0.01 * subsample
    return labformat
|
|
|
|
if __name__ == '__main__':
    # Script entry point: force-align every utterance of the input file
    # against its transcript using the model's CTC output, write one
    # "<key> <alignment>" line per utterance to --result_file and, with
    # --gen_praat, also write a .lab and a .TextGrid file per utterance next
    # to the result file.
    #
    # Keep `configs` and `char_dict` as module-level names: code outside this
    # block (get_labformat) may resolve them as globals.
    parser = argparse.ArgumentParser(
        description='use ctc to generate alignment')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--input_file', required=True, help='format data file')
    parser.add_argument('--data_type',
                        default='raw',
                        choices=['raw', 'shard'],
                        help='train and cv data type')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--device',
                        type=str,
                        default="cpu",
                        choices=["cpu", "npu", "cuda"],
                        help='accelerator to use')
    parser.add_argument('--blank_thres',
                        default=0.999999,
                        type=float,
                        help='ctc blank thes')
    parser.add_argument('--thres',
                        default=0.000001,
                        type=float,
                        help='ctc non blank thes')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--dict', required=True, help='dict file')
    parser.add_argument(
        '--non_lang_syms',
        help="non-linguistic symbol file. One symbol per line.")
    parser.add_argument('--result_file',
                        required=True,
                        help='alignment result file')
    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
    parser.add_argument('--gen_praat',
                        action='store_true',
                        help='convert alignment to a praat format')
    parser.add_argument('--bpe_model',
                        default=None,
                        type=str,
                        help='bpe model for english part')

    args = parser.parse_args()
    print(args)
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    if args.gpu != -1:
        # Legacy usage: passing --gpu implies running on cuda.
        args.device = "cuda"
        # Only restrict the visible devices when a GPU id was actually given;
        # exporting "-1" (the --gpu default) would hide every GPU when the
        # user selects --device cuda without --gpu.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    # The loop below squeezes the batch dimension, so exactly one utterance
    # per batch is required (0 and negative values are rejected as well).
    if args.batch_size != 1:
        logging.fatal('alignment mode must be running with batch_size == 1')
        sys.exit(1)

    # NOTE(review): yaml.FullLoader is not meant for untrusted input;
    # consider yaml.safe_load if configs can come from outside.
    with open(args.config, 'r', encoding='utf-8') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    # Load dict: each line is "<token> <id>"; build the id -> token mapping.
    char_dict = {}
    with open(args.dict, 'r', encoding='utf-8') as fin:
        for line in fin:
            arr = line.strip().split()
            if len(arr) != 2:
                raise ValueError(
                    'malformed dict line (expected "<token> <id>"): '
                    '{!r}'.format(line))
            char_dict[int(arr[1])] = arr[0]

    # Init dataset and data loader: start from the configured dataset
    # settings, widen the length filters and switch off augmentation,
    # shuffling and sorting.
    ali_conf = copy.deepcopy(configs['dataset_conf'])

    ali_conf['filter_conf']['max_length'] = 102400
    ali_conf['filter_conf']['min_length'] = 0
    ali_conf['filter_conf']['token_max_length'] = 102400
    ali_conf['filter_conf']['token_min_length'] = 0
    ali_conf['filter_conf']['max_output_input_ratio'] = 102400
    ali_conf['filter_conf']['min_output_input_ratio'] = 0
    ali_conf['speed_perturb'] = False
    ali_conf['spec_aug'] = False
    ali_conf['spec_trim'] = False
    ali_conf['shuffle'] = False
    ali_conf['sort'] = False
    ali_conf['fbank_conf']['dither'] = 0.0
    ali_conf['batch_conf']['batch_type'] = "static"
    ali_conf['batch_conf']['batch_size'] = args.batch_size

    tokenizer = init_tokenizer(configs)
    ali_dataset = Dataset(args.data_type,
                          args.input_file,
                          tokenizer,
                          ali_conf,
                          partition=False)

    ali_data_loader = DataLoader(ali_dataset, batch_size=None, num_workers=0)

    # Init asr model from configs (init_model also returns the configs).
    model, configs = init_model(args, configs)

    device = torch.device(args.device)
    model = model.to(device)

    # Per-utterance .lab / .TextGrid files go next to the result file.
    output_dir = os.path.dirname(args.result_file)

    model.eval()
    with torch.no_grad(), open(args.result_file, 'w',
                               encoding='utf-8') as fout:
        for batch in ali_data_loader:
            print("#" * 80)
            key, feat, target, feats_length, target_length = batch

            feat = feat.to(device)
            target = target.to(device)
            feats_length = feats_length.to(device)
            target_length = target_length.to(device)

            # Encoder forward pass followed by CTC log-softmax; the batch
            # dimension (size 1) is squeezed away before alignment.
            encoder_out, _ = model._forward_encoder(feat, feats_length)
            ctc_probs = model.ctc.log_softmax(encoder_out)
            ctc_probs = ctc_probs.squeeze(0)
            target = target.squeeze(0)
            alignment = force_align(ctc_probs, target)
            fout.write('{} {}\n'.format(key[0], alignment))

            if args.gen_praat:
                timestamp = get_frames_timestamp(alignment, ctc_probs,
                                                 args.blank_thres, args.thres)
                subsample = get_subsample(configs)
                labformat = get_labformat(timestamp, subsample)

                lab_path = os.path.join(output_dir, key[0] + ".lab")
                with open(lab_path, 'w', encoding='utf-8') as f:
                    f.writelines(labformat)

                textgrid_path = os.path.join(output_dir,
                                             key[0] + ".TextGrid")
                generator_textgrid(maxtime=(len(alignment) + 1) * 0.01 *
                                   subsample,
                                   lines=labformat,
                                   output=textgrid_path)
|
|