| | |
| | |
| | |
| | |
| |
|
| | import json |
| | import os |
| | import re |
| | import sys |
| |
|
| | import torch |
| | from examples.speech_recognition.data import AsrDataset |
| | from examples.speech_recognition.data.replabels import replabel_symbol |
| | from fairseq.data import Dictionary |
| | from fairseq.tasks import LegacyFairseqTask, register_task |
| |
|
| |
|
def get_asr_dataset_from_json(data_json_path, tgt_dict):
    """Parse a data json file and create an ``AsrDataset``.

    See scripts/asr_prep_json.py, which packs this json from raw files.

    Json example::

        {
          "utts": {
            "4771-29403-0025": {
              "input": {
                "length_ms": 170,
                "path": "/tmp/file1.flac"
              },
              "output": {
                "text": "HELLO \\n",
                "token": "HE LLO",
                "tokenid": "4815, 861"
              }
            },
            "1564-142299-0096": {
              ...
            }
          }
        }

    Utterances are sorted by ``length_ms`` (compared as integers), longest
    first. The speaker of each utterance is built from the first two
    dash-separated fields of its id, e.g. ``"4771-29403-0025"`` ->
    ``"4771_29403"``. ``tokenid`` is a comma-separated list of integer ids;
    whitespace around each id is ignored and an empty string yields an empty
    token list. ``tgt_dict.eos()`` is appended to every target.

    Args:
        data_json_path (str): path to the json file.
        tgt_dict: target dictionary; ``tgt_dict.eos()`` is appended to each
            target and the dictionary is passed on to ``AsrDataset``.

    Returns:
        AsrDataset: dataset over the sorted utterances.

    Raises:
        FileNotFoundError: if ``data_json_path`` is not an existing file.
        ValueError: if the json contains no utterances, or an utterance id
            does not match the ``<a>-<b>-<c>`` pattern used to derive the
            speaker.
    """
    if not os.path.isfile(data_json_path):
        raise FileNotFoundError("Dataset not found: {}".format(data_json_path))
    with open(data_json_path, "rb") as f:
        data_samples = json.load(f)["utts"]
    # Explicit check rather than ``assert`` so it still fires under ``python -O``.
    if not data_samples:
        raise ValueError("Dataset is empty: {}".format(data_json_path))

    sorted_samples = sorted(
        data_samples.items(),
        key=lambda sample: int(sample[1]["input"]["length_ms"]),
        reverse=True,
    )

    # Compiled once instead of being re-looked-up for every utterance.
    speaker_re = re.compile(r"(.+?)-(.+?)-(.+?)")

    aud_paths, ids, speakers, frame_sizes, tgt = [], [], [], [], []
    for utt_id, sample in sorted_samples:
        m = speaker_re.search(utt_id)
        if m is None:
            raise ValueError(
                "Cannot derive speaker from utterance id {!r} in {}".format(
                    utt_id, data_json_path
                )
            )
        aud_paths.append(sample["input"]["path"])
        ids.append(utt_id)
        speakers.append(m.group(1) + "_" + m.group(2))
        frame_sizes.append(sample["input"]["length_ms"])
        # Split on "," and let int() strip spaces: accepts "1, 2", "1,2" and
        # an empty transcription ("") without crashing on int("").
        token_ids = [
            int(tok) for tok in sample["output"]["tokenid"].split(",") if tok.strip()
        ]
        tgt.append([*token_ids, tgt_dict.eos()])

    return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers)
| |
|
| |
|
@register_task("speech_recognition")
class SpeechRecognitionTask(LegacyFairseqTask):
    """
    Task for training speech recognition model.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser.

        Registers the positional ``data`` directory and the
        ``--silence-token``, ``--max-source-positions`` and
        ``--max-target-positions`` options.
        """
        parser.add_argument("data", help="path to data directory")
        parser.add_argument(
            "--silence-token", default="\u2581", help="token for silence (used by w2l)"
        )
        parser.add_argument(
            "--max-source-positions",
            default=sys.maxsize,
            type=int,
            metavar="N",
            help="max number of frames in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )

    def __init__(self, args, tgt_dict):
        super().__init__(args)
        # Target dictionary; returned by the ``target_dictionary`` property.
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Loads ``<args.data>/dict.txt`` and extends it with criterion-specific
        symbols: ``<ctc_blank>`` when ``args.criterion == "ctc_loss"``, or
        ``replabel_symbol(1)`` .. ``replabel_symbol(args.max_replabel)`` when
        ``args.criterion == "asg_loss"``.

        Raises:
            FileNotFoundError: if the dictionary file does not exist.
        """
        dict_path = os.path.join(args.data, "dict.txt")
        if not os.path.isfile(dict_path):
            raise FileNotFoundError("Dict not found: {}".format(dict_path))
        tgt_dict = Dictionary.load(dict_path)

        if args.criterion == "ctc_loss":
            tgt_dict.add_symbol("<ctc_blank>")
        elif args.criterion == "asg_loss":
            for i in range(1, args.max_replabel + 1):
                tgt_dict.add_symbol(replabel_symbol(i))

        print("| dictionary: {} types".format(len(tgt_dict)))
        return cls(args, tgt_dict)

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Reads ``<args.data>/<split>.json`` and stores the resulting dataset in
        ``self.datasets[split]``. ``combine`` and extra kwargs are accepted
        for interface compatibility but not used here.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        data_json_path = os.path.join(self.args.data, "{}.json".format(split))
        self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict)

    def build_generator(self, models, args, **kwargs):
        """Build the inference generator.

        If ``args.w2l_decoder`` is ``"viterbi"``, ``"kenlm"`` or
        ``"fairseqlm"``, the matching wav2letter decoder is built (imported
        lazily). Otherwise this defers to the parent implementation and
        forwards any extra keyword arguments to it.
        """
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, self.target_dictionary)
        elif w2l_decoder == "kenlm":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(args, self.target_dictionary)
        elif w2l_decoder == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(args, self.target_dictionary)
        # Forward caller-supplied options (e.g. ``extra_gen_cls_kwargs``)
        # instead of silently discarding them.
        return super().build_generator(models, args, **kwargs)

    @property
    def target_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the language
        model."""
        return self.tgt_dict

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        return None

    def max_positions(self):
        """Return the max speech and sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)
| |
|