Spaces:
Running
Running
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import json | |
| import os | |
| import re | |
| import sys | |
| import torch | |
| from examples.speech_recognition.data import AsrDataset | |
| from examples.speech_recognition.data.replabels import replabel_symbol | |
| from fairseq.data import Dictionary | |
| from fairseq.tasks import LegacyFairseqTask, register_task | |
def get_asr_dataset_from_json(data_json_path, tgt_dict):
    """
    Parse a data json file and create the ASR dataset from it.

    See scripts/asr_prep_json.py, which packs the json from raw files.

    Json example:
    {
        "utts": {
            "4771-29403-0025": {
                "input": {
                    "length_ms": 170,
                    "path": "/tmp/file1.flac"
                },
                "output": {
                    "text": "HELLO \\n",
                    "token": "HE LLO",
                    "tokenid": "4815, 861"
                }
            },
            "1564-142299-0096": {
                ...
            }
        }
    }

    Args:
        data_json_path: path to the json file described above.
        tgt_dict: target dictionary; only ``tgt_dict.eos()`` is used here
            (appended to every target sequence) before the dictionary is
            handed on to ``AsrDataset``.

    Returns:
        An ``AsrDataset`` built from the utterances, ordered by decreasing
        ``length_ms``.

    Raises:
        FileNotFoundError: if ``data_json_path`` is not an existing file.
        AssertionError: if the ``"utts"`` mapping is empty.
        ValueError: if an utterance id does not have the
            ``<part1>-<part2>-<rest>`` shape needed to derive a speaker name.
    """
    if not os.path.isfile(data_json_path):
        raise FileNotFoundError("Dataset not found: {}".format(data_json_path))
    with open(data_json_path, "rb") as f:
        data_samples = json.load(f)["utts"]
    assert len(data_samples) != 0, "No utterances found in {}".format(
        data_json_path
    )

    # Longest utterances first.
    sorted_samples = sorted(
        data_samples.items(),
        key=lambda sample: int(sample[1]["input"]["length_ms"]),
        reverse=True,
    )
    aud_paths = [s[1]["input"]["path"] for s in sorted_samples]
    ids = [s[0] for s in sorted_samples]

    # The speaker name is the first two hyphen-separated fields of the id,
    # e.g. "4771-29403-0025" -> "4771_29403". Compile the pattern once
    # instead of once per utterance.
    utt_id_pattern = re.compile(r"(.+?)-(.+?)-(.+?)")
    speakers = []
    for utt_id in ids:
        m = utt_id_pattern.search(utt_id)
        if m is None:
            # Without this check the failure would surface as an opaque
            # "'NoneType' object has no attribute 'group'" error.
            raise ValueError(
                "Cannot derive speaker from utterance id {!r} in {}: expected "
                "an id of the form '<a>-<b>-<c>'".format(utt_id, data_json_path)
            )
        speakers.append(m.group(1) + "_" + m.group(2))

    frame_sizes = [s[1]["input"]["length_ms"] for s in sorted_samples]
    tgt = [
        [int(i) for i in s[1]["output"]["tokenid"].split(", ")]
        for s in sorted_samples
    ]
    # append eos
    tgt = [[*t, tgt_dict.eos()] for t in tgt]
    return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers)
class SpeechRecognitionTask(LegacyFairseqTask):
    """
    Task for training speech recognition model.

    The task loads a target dictionary from ``<data>/dict.txt`` and builds
    one dataset per split from ``<data>/<split>.json``.
    """

    # NOTE(review): this file imports `register_task` but the visible code
    # never applies it to this class -- confirm whether a
    # `@register_task(...)` decorator is expected here.

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument("data", help="path to data directory")
        parser.add_argument(
            "--silence-token", default="\u2581", help="token for silence (used by w2l)"
        )
        parser.add_argument(
            "--max-source-positions",
            default=sys.maxsize,
            type=int,
            metavar="N",
            help="max number of frames in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )

    def __init__(self, args, tgt_dict):
        """Store the parsed args (via the base class) and the target dictionary."""
        super().__init__(args)
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Loads ``dict.txt`` from ``args.data`` and extends it with the extra
        symbols required by the selected criterion before constructing the
        task.

        Raises:
            FileNotFoundError: if ``dict.txt`` is missing from ``args.data``.
        """
        dict_path = os.path.join(args.data, "dict.txt")
        if not os.path.isfile(dict_path):
            raise FileNotFoundError("Dict not found: {}".format(dict_path))
        tgt_dict = Dictionary.load(dict_path)

        if args.criterion == "ctc_loss":
            tgt_dict.add_symbol("<ctc_blank>")
        elif args.criterion == "asg_loss":
            # One replabel symbol for each index in 1..max_replabel.
            for i in range(1, args.max_replabel + 1):
                tgt_dict.add_symbol(replabel_symbol(i))

        print("| dictionary: {} types".format(len(tgt_dict)))
        return cls(args, tgt_dict)

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        data_json_path = os.path.join(self.args.data, "{}.json".format(split))
        self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict)

    def build_generator(self, models, args, **unused):
        """Build the decoder selected by ``args.w2l_decoder``.

        ``viterbi``, ``kenlm`` and ``fairseqlm`` select the corresponding
        w2l decoder; any other value (or a missing attribute) falls back to
        the base-class generator.
        """
        w2l_decoder = getattr(args, "w2l_decoder", None)
        # Decoders are imported lazily, only when the matching one is requested.
        if w2l_decoder == "viterbi":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, self.target_dictionary)
        elif w2l_decoder == "kenlm":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(args, self.target_dictionary)
        elif w2l_decoder == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(args, self.target_dictionary)
        else:
            return super().build_generator(models, args)

    # Must be a property: build_generator above reads `self.target_dictionary`
    # as an attribute and hands the result to the decoders.
    @property
    def target_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the language
        model."""
        return self.tgt_dict

    # Kept as a property so it is accessed the same way as target_dictionary.
    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        return None

    def max_positions(self):
        """Return the max speech and sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)