# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from argparse import Namespace
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import fairseq.data.audio.feature_transforms.utterance_cmvn as utt_cmvn
from fairseq.data import encoders
from fairseq.data.audio.audio_utils import convert_waveform as convert_wav
from fairseq.data.audio.audio_utils import get_fbank
from fairseq.data.audio.audio_utils import get_waveform as get_wav
from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDataset

logger = logging.getLogger(__name__)


class S2THubInterface(nn.Module):
    """Hub interface for speech-to-text (S2T) inference: bundles a task, a
    trained model, and a sequence generator behind a single `predict()` call.
    """

    def __init__(self, cfg, task, model):
        super().__init__()
        self.cfg = cfg
        self.task = task
        self.model = model
        self.model.eval()
        self.generator = self.task.build_generator([self.model], self.cfg.generation)

    @classmethod
    def get_model_input(cls, task, audio: Union[str, torch.Tensor]):
        # Build a one-sample batch from either an audio file path or a raw
        # 1 x T waveform tensor, matching the input type the model expects.
        input_type = task.data_cfg.hub.get("input_type", "fbank80")
        if input_type == "fbank80_w_utt_cmvn":
            if isinstance(audio, str):
                feat = utt_cmvn.UtteranceCMVN()(get_fbank(audio))  # np.ndarray
                feat = np.expand_dims(feat, axis=0)  # T x D -> 1 x T x D
            else:
                import torchaudio.compliance.kaldi as kaldi

                # Note: utterance CMVN is applied only in the file-path branch.
                feat = kaldi.fbank(audio, num_mel_bins=80).numpy()  # T x D
                feat = np.expand_dims(feat, axis=0)  # T x D -> 1 x T x D
        elif input_type in {"waveform", "standardized_waveform"}:
            if isinstance(audio, str):
                feat, sr = get_wav(audio)  # C x T
                feat, _ = convert_wav(
                    feat, sr, to_sample_rate=16_000, to_mono=True
                )  # C x T -> 1 x T
            else:
                feat = audio.numpy()
        else:
            raise ValueError(f"Unknown value: input_type = {input_type}")

        src_lengths = torch.Tensor([feat.shape[1]]).long()
        src_tokens = torch.from_numpy(feat)  # 1 x T (x D)
        if input_type == "standardized_waveform":
            with torch.no_grad():
                src_tokens = F.layer_norm(src_tokens, src_tokens.shape)

        return {
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "prev_output_tokens": None,
            },
            "target_lengths": None,
            "speaker": None,
        }

    @classmethod
    def detokenize(cls, task, tokens):
        # Convert token indices to a string, undoing BPE if one is configured.
        text = task.tgt_dict.string(tokens)
        tkn_cfg = task.data_cfg.bpe_tokenizer
        tokenizer = encoders.build_bpe(Namespace(**tkn_cfg))
        return text if tokenizer is None else tokenizer.decode(text)

    @classmethod
    def get_prefix_token(cls, task, lang):
        # When the task prepends target-language tags, force decoding to
        # start with the tag for `lang`.
        prefix_size = int(task.data_cfg.prepend_tgt_lang_tag)
        prefix_tokens = None
        if prefix_size > 0:
            assert lang is not None
            lang_tag = SpeechToTextDataset.get_lang_tag_idx(lang, task.tgt_dict)
            prefix_tokens = torch.Tensor([lang_tag]).long().unsqueeze(0)
        return prefix_tokens

    @classmethod
    def get_prediction(
        cls, task, model, generator, sample, tgt_lang=None, synthesize_speech=False
    ) -> Union[str, Tuple[str, Tuple[torch.Tensor, int]]]:
        _tgt_lang = tgt_lang or task.data_cfg.hub.get("tgt_lang", None)
        prefix = cls.get_prefix_token(task, _tgt_lang)
        pred_tokens = generator.generate([model], sample, prefix_tokens=prefix)
        pred = cls.detokenize(task, pred_tokens[0][0]["tokens"])
        eos_token = task.data_cfg.config.get("eos_token", None)
        if eos_token:
            # Drop the trailing EOS token from the detokenized string.
            pred = " ".join(pred.split(" ")[:-1])

        if synthesize_speech:
            pfx = f"{_tgt_lang}_" if task.data_cfg.prepend_tgt_lang_tag else ""
            tts_model_id = task.data_cfg.hub.get(f"{pfx}tts_model_id", None)
            speaker = task.data_cfg.hub.get(f"{pfx}speaker", None)
            if tts_model_id is None:
                logger.warning("TTS model configuration not found")
            else:
                _repo, _id = tts_model_id.split(":")
                tts_model = torch.hub.load(_repo, _id, verbose=False)
                pred = (pred, tts_model.predict(pred, speaker=speaker))
        return pred

    def predict(
        self,
        audio: Union[str, torch.Tensor],
        tgt_lang: Optional[str] = None,
        synthesize_speech: bool = False,
    ) -> Union[str, Tuple[str, Tuple[torch.Tensor, int]]]:
        # `audio` is either a file path or a 1 x T tensor; returns either
        # text or (text, (synthesized waveform, sample rate)).
        sample = self.get_model_input(self.task, audio)
        return self.get_prediction(
            self.task,
            self.model,
            self.generator,
            sample,
            tgt_lang=tgt_lang,
            synthesize_speech=synthesize_speech,
        )
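

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). The checkpoint, data, and
# audio paths below are placeholders, and the hub section this interface
# reads from `task.data_cfg.hub` is assumed to look roughly like:
#
#     hub:
#       input_type: fbank80_w_utt_cmvn
#       tgt_lang: de                    # optional, for multilingual models
#       tts_model_id: repo:model_name   # optional "<hub repo>:<entry>" pair
#
# With a task and model loaded via fairseq's checkpoint utilities, the
# interface can be driven directly:
#
#     from fairseq import checkpoint_utils
#
#     models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
#         ["/path/to/s2t_checkpoint.pt"],  # placeholder checkpoint path
#         arg_overrides={"data": "/path/to/data_root"},  # placeholder path
#     )
#     hub = S2THubInterface(cfg, task, models[0])
#     text = hub.predict("/path/to/audio.wav")  # placeholder audio file
# ---------------------------------------------------------------------------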