#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import warnings
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional

import torch
from omegaconf import MISSING

from fairseq.dataclass import FairseqDataclass

# note: the "initalize_kaldi" spelling matches the name exported by the
# kaldi_initializer module, so it is kept as-is
from .kaldi_initializer import KaldiInitializerConfig, initalize_kaldi

logger = logging.getLogger(__name__)

@dataclass
class KaldiDecoderConfig(FairseqDataclass):
    hlg_graph_path: Optional[str] = None
    output_dict: str = MISSING

    kaldi_initializer_config: Optional[KaldiInitializerConfig] = None

    acoustic_scale: float = 0.5
    max_active: int = 10000
    beam_delta: float = 0.5
    hash_ratio: float = 2.0

    is_lattice: bool = False
    lattice_beam: float = 10.0
    prune_interval: int = 25
    determinize_lattice: bool = True
    prune_scale: float = 0.1
    max_mem: int = 0
    phone_determinize: bool = True
    word_determinize: bool = True
    minimize: bool = True

    num_threads: int = 1

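# A minimal sketch of the expected output_dict file format, inferred from the
# parsing in KaldiDecoder.__init__ below (the exact symbol names are
# illustrative assumptions): one "<word> <int-id>" pair per line, e.g.
#
#   <eps> 0
#   hello 1
#   world 2
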
class KaldiDecoder(object):
    def __init__(
        self,
        cfg: KaldiDecoderConfig,
        beam: int,
        nbest: int = 1,
    ):
        try:
            from kaldi.asr import FasterRecognizer, LatticeFasterRecognizer
            from kaldi.base import set_verbose_level
            from kaldi.decoder import (
                FasterDecoder,
                FasterDecoderOptions,
                LatticeFasterDecoder,
                LatticeFasterDecoderOptions,
            )
            from kaldi.fstext import read_fst_kaldi, SymbolTable
            from kaldi.lat.functions import DeterminizeLatticePhonePrunedOptions
        except ImportError:
            warnings.warn(
                "pykaldi is required for this functionality. "
                "Please install from https://github.com/pykaldi/pykaldi"
            )
            raise
        # set_verbose_level(2)
        self.acoustic_scale = cfg.acoustic_scale
        self.nbest = nbest

        # Build the decoding graph from the initializer config if no
        # pre-built HLG graph was provided.
        if cfg.hlg_graph_path is None:
            assert (
                cfg.kaldi_initializer_config is not None
            ), "Must provide hlg graph path or kaldi initializer config"
            cfg.hlg_graph_path = initalize_kaldi(cfg.kaldi_initializer_config)

        assert os.path.exists(cfg.hlg_graph_path), cfg.hlg_graph_path

        if cfg.is_lattice:
            self.dec_cls = LatticeFasterDecoder
            opt_cls = LatticeFasterDecoderOptions
            self.rec_cls = LatticeFasterRecognizer
        else:
            assert self.nbest == 1, "nbest > 1 requires lattice decoder"
            self.dec_cls = FasterDecoder
            opt_cls = FasterDecoderOptions
            self.rec_cls = FasterRecognizer

        self.decoder_options = opt_cls()
        self.decoder_options.beam = beam
        self.decoder_options.max_active = cfg.max_active
        self.decoder_options.beam_delta = cfg.beam_delta
        self.decoder_options.hash_ratio = cfg.hash_ratio

        if cfg.is_lattice:
            self.decoder_options.lattice_beam = cfg.lattice_beam
            self.decoder_options.prune_interval = cfg.prune_interval
            self.decoder_options.determinize_lattice = cfg.determinize_lattice
            self.decoder_options.prune_scale = cfg.prune_scale
            det_opts = DeterminizeLatticePhonePrunedOptions()
            det_opts.max_mem = cfg.max_mem
            det_opts.phone_determinize = cfg.phone_determinize
            det_opts.word_determinize = cfg.word_determinize
            det_opts.minimize = cfg.minimize
            self.decoder_options.det_opts = det_opts

        # Map integer output ids back to word strings.
        self.output_symbols = {}
        with open(cfg.output_dict, "r") as f:
            for line in f:
                items = line.rstrip().split()
                assert len(items) == 2
                self.output_symbols[int(items[1])] = items[0]

        logger.info(f"Loading FST from {cfg.hlg_graph_path}")

        self.fst = read_fst_kaldi(cfg.hlg_graph_path)
        self.symbol_table = SymbolTable.read_text(cfg.output_dict)

        self.executor = ThreadPoolExecutor(max_workers=cfg.num_threads)

    def generate(self, models, sample, **unused):
        """Generate a batch of inferences.

        Returns a list of concurrent.futures.Future objects, one per
        utterance; call .result() on each to obtain the hypotheses.
        """
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions, padding = self.get_emissions(models, encoder_input)
        return self.decode(emissions, padding)

    def get_emissions(self, models, encoder_input):
        """Run encoder and normalize emissions"""
        model = models[0]

        all_encoder_out = [m(**encoder_input) for m in models]

        if len(all_encoder_out) > 1:
            # Ensemble: average the raw outputs and reuse the first model's
            # padding mask.
            if "encoder_out" in all_encoder_out[0]:
                encoder_out = {
                    "encoder_out": sum(e["encoder_out"] for e in all_encoder_out)
                    / len(all_encoder_out),
                    "encoder_padding_mask": all_encoder_out[0]["encoder_padding_mask"],
                }
                padding = encoder_out["encoder_padding_mask"]
            else:
                encoder_out = {
                    "logits": sum(e["logits"] for e in all_encoder_out)
                    / len(all_encoder_out),
                    "padding_mask": all_encoder_out[0]["padding_mask"],
                }
                padding = encoder_out["padding_mask"]
        else:
            encoder_out = all_encoder_out[0]
            padding = (
                encoder_out["padding_mask"]
                if "padding_mask" in encoder_out
                else encoder_out["encoder_padding_mask"]
            )

        if hasattr(model, "get_logits"):
            emissions = model.get_logits(encoder_out, normalize=True)
        else:
            emissions = model.get_normalized_probs(encoder_out, log_probs=True)

        return (
            emissions.cpu().float().transpose(0, 1),
            padding.cpu() if padding is not None and padding.any() else None,
        )

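    # Shape note (an inference from the code above, not a documented
    # contract): after the transpose, get_emissions returns emissions as a
    # (batch, time, vocab) CPU float tensor, and padding as a (batch, time)
    # bool mask or None; decode_one below then receives one (time, vocab)
    # slice and one (time,) mask per utterance.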
    def decode_one(self, logits, padding):
        from kaldi.matrix import Matrix

        decoder = self.dec_cls(self.fst, self.decoder_options)
        asr = self.rec_cls(
            decoder, self.symbol_table, acoustic_scale=self.acoustic_scale
        )

        # Strip padded frames before decoding.
        if padding is not None:
            logits = logits[~padding]

        mat = Matrix(logits.numpy())

        out = asr.decode(mat)

        if self.nbest > 1:
            from kaldi.fstext import shortestpath
            from kaldi.fstext.utils import (
                convert_compact_lattice_to_lattice,
                convert_lattice_to_std,
                convert_nbest_to_list,
                get_linear_symbol_sequence,
            )

            # Extract the nbest shortest paths from the lattice and map the
            # output ids back to words.
            lat = out["lattice"]

            sp = shortestpath(lat, nshortest=self.nbest)

            sp = convert_compact_lattice_to_lattice(sp)
            sp = convert_lattice_to_std(sp)
            seq = convert_nbest_to_list(sp)

            results = []
            for s in seq:
                _, o, w = get_linear_symbol_sequence(s)
                words = list(self.output_symbols[z] for z in o)
                results.append(
                    {
                        "tokens": words,
                        "words": words,
                        "score": w.value,
                        "emissions": logits,
                    }
                )
            return results
        else:
            words = out["text"].split()
            return [
                {
                    "tokens": words,
                    "words": words,
                    "score": out["likelihood"],
                    "emissions": logits,
                }
            ]

    def decode(self, emissions, padding):
        """Dispatch one decode_one job per utterance to the thread pool."""
        if padding is None:
            padding = [None] * len(emissions)

        ret = list(
            map(
                lambda e, p: self.executor.submit(self.decode_one, e, p),
                emissions,
                padding,
            )
        )
        return ret
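
# Example usage (a minimal sketch; the paths, the model, and the sample dict
# are illustrative assumptions, not part of this module):
#
#   cfg = KaldiDecoderConfig(
#       hlg_graph_path="/path/to/HLG.fst",   # hypothetical path
#       output_dict="/path/to/words.txt",    # hypothetical path
#       is_lattice=True,                     # required for nbest > 1
#   )
#   decoder = KaldiDecoder(cfg, beam=15, nbest=5)
#   futures = decoder.generate([model], sample)  # "model", "sample" assumed
#   for fut in futures:
#       hypos = fut.result()  # list of {"tokens", "words", "score", "emissions"}
#       print(" ".join(hypos[0]["words"]))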