| |
| |
| |
| |
| |
|
|
| from __future__ import absolute_import, division, print_function, unicode_literals |
|
|
| import argparse |
| import concurrent.futures |
| import json |
| import multiprocessing |
| import os |
| from collections import namedtuple |
| from itertools import chain |
|
|
| import sentencepiece as spm |
| from fairseq.data import Dictionary |
|
|
|
|
# Length of one millisecond expressed in seconds; dividing a duration in
# seconds by this factor yields milliseconds.
MILLISECONDS_TO_SECONDS = 1 / 1000
|
|
|
|
def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
    """Build the manifest entry for a single utterance.

    Args:
        aud_path: path to the audio file; its duration is read with
            ``torchaudio.info``.
        lable: transcript text of the utterance (parameter name kept as-is
            so existing callers are unaffected).
        utt_id: utterance id, used as the key of the returned dict.
        sp: loaded SentencePiece processor used to split ``lable`` into
            pieces via ``EncodeAsPieces``.
        tgt_dict: fairseq ``Dictionary`` whose ``encode_line`` maps the
            space-joined pieces to token ids (no EOS appended).

    Returns:
        ``{utt_id: {"input": {...}, "output": {...}}}`` where ``input`` has
        ``length_ms`` (duration in milliseconds, truncated to int) and
        ``path``, and ``output`` has ``text`` (the raw label), ``token`` (the
        space-joined pieces) and ``tokenid`` (the ids joined with ", ").
    """
    # Imported locally (presumably so the script loads without torchaudio
    # until audio is actually processed).
    import torchaudio

    info = torchaudio.info(aud_path)
    if hasattr(info, "num_frames"):
        # Newer torchaudio returns a single AudioMetaData object whose
        # num_frames counts frames per channel.
        duration_s = info.num_frames / info.sample_rate
    else:
        # Legacy sox API returns (signal_info, encoding_info); the signal
        # length is divided by the channel count, as the original code did,
        # to get per-channel samples before dividing by the sample rate.
        si, _ = info
        duration_s = si.length / si.channels / si.rate

    input_info = {
        "length_ms": int(duration_s / MILLISECONDS_TO_SECONDS),
        "path": aud_path,
    }

    token = " ".join(sp.EncodeAsPieces(lable))
    ids = tgt_dict.encode_line(token, append_eos=False)
    output_info = {
        "text": lable,
        "token": token,
        # Each element of ``ids`` is converted with ``tolist()`` to a plain
        # Python number before being stringified.
        "tokenid": ", ".join(str(t.tolist()) for t in ids),
    }
    return {utt_id: {"input": input_info, "output": output_info}}
|
|
|
|
# Record for one audio file that has a matching label.
_Sample = namedtuple("Sample", "aud_path utt_id")


def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    # ``required=True`` makes a default unreachable, so none is given.
    parser.add_argument(
        "--audio-dirs",
        nargs="+",
        required=True,
        help="input directories with audio files",
    )
    parser.add_argument(
        "--labels",
        required=True,
        help="aggregated input labels with format <ID LABEL> per line",
        type=argparse.FileType("r", encoding="UTF-8"),
    )
    parser.add_argument(
        "--spm-model",
        required=True,
        help="sentencepiece model to use for encoding",
        type=argparse.FileType("r", encoding="UTF-8"),
    )
    parser.add_argument(
        "--dictionary",
        required=True,
        help="file to load fairseq dictionary from",
        type=argparse.FileType("r", encoding="UTF-8"),
    )
    parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
    parser.add_argument(
        "--output",
        required=True,
        type=argparse.FileType("w"),
        help="path to save json output",
    )
    return parser.parse_args()


def _load_labels(label_file):
    """Read ``<ID LABEL>`` lines into a dict mapping utterance id to label.

    The line terminator is removed so it does not end up in the label text.
    Blank lines are skipped.

    Raises:
        ValueError: a non-blank line has no space separating id and label.
        Exception: the file contains no labels at all.
    """
    labels = {}
    for line_no, line in enumerate(label_file, 1):
        line = line.rstrip("\r\n")
        if not line.strip():
            continue
        utt_id, sep, label = line.partition(" ")
        if not sep:
            raise ValueError(
                "{}:{}: expected '<ID LABEL>', got {!r}".format(
                    label_file.name, line_no, line
                )
            )
        labels[utt_id] = label
    if not labels:
        raise Exception("No labels found in {}".format(label_file.name))
    return labels


def _find_samples(audio_dirs, audio_format, labels):
    """Recursively collect audio files of ``audio_format`` that have a label.

    The utterance id is the file name without its extension. Files whose id
    is not a key of ``labels`` are ignored.
    """
    wanted_ext = "." + audio_format
    samples = []
    for root, _, files in chain.from_iterable(os.walk(d) for d in audio_dirs):
        for fname in files:
            utt_id, ext = os.path.splitext(fname)
            # Compare the real extension, so e.g. "foo.xwav" is not taken
            # as a ".wav" file.
            if ext != wanted_ext or utt_id not in labels:
                continue
            samples.append(_Sample(os.path.join(root, fname), utt_id))
    return samples


def main():
    """Build the ASR JSON manifest from audio directories and a label file.

    Every labelled audio file is processed with ``process_sample`` on a
    thread pool. The combined result ``{"utts": {...}}`` is written to
    ``--output``. A sample whose processing raises is reported on stderr and
    left out of the manifest.
    """
    import sys

    args = _parse_args()

    sp = spm.SentencePieceProcessor()
    sp.Load(args.spm_model.name)

    tgt_dict = Dictionary.load(args.dictionary)

    labels = _load_labels(args.labels)
    samples = _find_samples(args.audio_dirs, args.audio_format, labels)

    utts = {}
    num_cpu = multiprocessing.cpu_count()
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
        pending = [
            (
                s,
                executor.submit(
                    process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
                ),
            )
            for s in samples
        ]
        # Collect results in submission order so the output is deterministic
        # for a given directory walk.
        for sample, future in pending:
            try:
                data = future.result()
            except Exception as exc:
                # Best effort: report which file failed and keep going.
                print(
                    "generated an exception for {}: {}".format(sample.aud_path, exc),
                    file=sys.stderr,
                )
            else:
                utts.update(data)

    try:
        json.dump({"utts": utts}, args.output, indent=4)
    finally:
        # Close the output file to flush it, but never close stdout ("-").
        if args.output is not sys.stdout:
            args.output.close()
|
|
|
|
if __name__ == "__main__":
    # Run the command-line entry point only when this file is executed as a
    # script, not when it is imported.
    main()
|
|