Spaces:
Sleeping
Sleeping
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import logging | |
| from argparse import Namespace | |
| from typing import Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import fairseq.data.audio.feature_transforms.utterance_cmvn as utt_cmvn | |
| from fairseq.data import encoders | |
| from fairseq.data.audio.audio_utils import convert_waveform as convert_wav | |
| from fairseq.data.audio.audio_utils import get_fbank | |
| from fairseq.data.audio.audio_utils import get_waveform as get_wav | |
| from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDataset | |
| logger = logging.getLogger(__name__) | |
class S2THubInterface(nn.Module):
    """Inference wrapper around a speech-to-text task, model and generator.

    The individual steps (input preparation, detokenization, language-prefix
    construction, decoding) are class methods that receive the task/model
    explicitly, so they can be called on the class as well as on an
    instance. ``predict`` chains them for the wrapped model.
    """

    def __init__(self, cfg, task, model):
        """Store the components, put the model in eval mode, build a generator.

        Args:
            cfg: configuration object; ``cfg.generation`` is passed to
                ``task.build_generator``.
            task: task object; this class uses its ``build_generator``,
                ``data_cfg`` and ``tgt_dict`` members.
            model: model to run; ``model.eval()`` is called on it here.
        """
        super().__init__()
        self.cfg = cfg
        self.task = task
        self.model = model
        self.model.eval()
        self.generator = self.task.build_generator([self.model], self.cfg.generation)

    @classmethod
    def get_model_input(cls, task, audio: Union[str, torch.Tensor]):
        """Build a single-utterance sample from an audio path or tensor.

        The feature type is read from ``task.data_cfg.hub["input_type"]``:

        * ``"fbank80_w_utt_cmvn"``: 80-bin filterbank features. A file path
          goes through ``get_fbank`` followed by ``UtteranceCMVN``; a tensor
          goes through ``torchaudio.compliance.kaldi.fbank``.
        * ``"waveform"``: raw samples. A file is loaded and converted to
          16 kHz mono; a tensor is used as given.
        * ``"standardized_waveform"``: same as ``"waveform"``, then
          layer-normalized over the whole tensor.

        Args:
            task: task object exposing ``data_cfg.hub`` (a mapping).
            audio: path to an audio file, or a waveform tensor
                (``predict`` documents it as ``1 x T``).

        Returns:
            A dict with ``net_input`` (``src_tokens`` of shape
            ``1 x T (x D)``, ``src_lengths`` holding ``T`` as a long tensor,
            ``prev_output_tokens`` set to ``None``) plus ``target_lengths``
            and ``speaker``, both ``None``.

        Raises:
            ValueError: for any other ``input_type``, including the
                ``"fbank80"`` fallback used when the key is absent.
        """
        input_type = task.data_cfg.hub.get("input_type", "fbank80")
        if input_type == "fbank80_w_utt_cmvn":
            if isinstance(audio, str):
                feat = utt_cmvn.UtteranceCMVN()(get_fbank(audio))  # T x D
            else:
                import torchaudio.compliance.kaldi as kaldi

                # kaldi.fbank returns a 2-D (frames x bins) tensor.
                # NOTE(review): unlike the file path, tensor input is not
                # passed through UtteranceCMVN -- confirm this is intended.
                feat = kaldi.fbank(audio, num_mel_bins=80).numpy()  # T x D
            # Add the batch axis by indexing: `feat` must be an ndarray for
            # torch.from_numpy below, and ndarrays have no `.unsqueeze`.
            feat = feat[None, ...]  # T x D -> 1 x T x D
        elif input_type in {"waveform", "standardized_waveform"}:
            if isinstance(audio, str):
                feat, sr = get_wav(audio)  # C x T
                feat, _ = convert_wav(
                    feat, sr, to_sample_rate=16_000, to_mono=True
                )  # C x T -> 1 x T
            else:
                feat = audio.numpy()
        else:
            raise ValueError(f"Unknown value: input_type = {input_type}")

        # Build the length directly as int64. Going through a float32 tensor
        # rounds values above 2**24 (about 17.5 minutes of 16 kHz audio).
        src_lengths = torch.tensor([feat.shape[1]], dtype=torch.long)
        src_tokens = torch.from_numpy(feat)  # 1 x T (x D)
        if input_type == "standardized_waveform":
            with torch.no_grad():
                src_tokens = F.layer_norm(src_tokens, src_tokens.shape)

        return {
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "prev_output_tokens": None,
            },
            "target_lengths": None,
            "speaker": None,
        }

    @classmethod
    def detokenize(cls, task, tokens):
        """Convert token indices to text, undoing BPE when one is configured.

        Args:
            task: task object exposing ``tgt_dict`` and
                ``data_cfg.bpe_tokenizer`` (a mapping of tokenizer options).
            tokens: token indices accepted by ``task.tgt_dict.string``.

        Returns:
            The dictionary string, passed through ``tokenizer.decode`` when
            ``encoders.build_bpe`` returns a tokenizer.
        """
        text = task.tgt_dict.string(tokens)
        tkn_cfg = task.data_cfg.bpe_tokenizer
        tokenizer = encoders.build_bpe(Namespace(**tkn_cfg))
        return text if tokenizer is None else tokenizer.decode(text)

    @classmethod
    def get_prefix_token(cls, task, lang):
        """Return the target-language prefix for generation, if one is used.

        Args:
            task: task object exposing ``data_cfg.prepend_tgt_lang_tag`` and
                ``tgt_dict``.
            lang: target language; required when the tag is prepended.

        Returns:
            ``None`` when ``prepend_tgt_lang_tag`` is falsy, otherwise a
            ``1 x 1`` long tensor holding the language tag index.
        """
        prefix_size = int(task.data_cfg.prepend_tgt_lang_tag)
        if prefix_size <= 0:
            return None
        assert lang is not None, "tgt_lang is required when prepend_tgt_lang_tag is set"
        lang_tag = SpeechToTextDataset.get_lang_tag_idx(lang, task.tgt_dict)
        # Build as int64 directly instead of casting a float32 tensor.
        return torch.tensor([[lang_tag]], dtype=torch.long)

    @classmethod
    def get_prediction(
        cls, task, model, generator, sample, tgt_lang=None, synthesize_speech=False
    ) -> Union[str, Tuple[str, Tuple[torch.Tensor, int]]]:
        """Decode ``sample`` to text and optionally synthesize speech for it.

        Args:
            task: task object exposing ``data_cfg`` and ``tgt_dict``.
            model: model passed to ``generator.generate``.
            generator: object with ``generate(models, sample, prefix_tokens=...)``.
            sample: sample dict as produced by ``get_model_input``.
            tgt_lang: target language; falls back to
                ``task.data_cfg.hub["tgt_lang"]`` when falsy.
            synthesize_speech: when true, also run the TTS model named in
                the hub configuration on the predicted text.

        Returns:
            The predicted text, or ``(text, tts_output)`` when speech was
            synthesized. If synthesis is requested but no TTS model id is
            configured, a warning is logged and only the text is returned.
        """
        _tgt_lang = tgt_lang or task.data_cfg.hub.get("tgt_lang", None)
        prefix = cls.get_prefix_token(task, _tgt_lang)
        pred_tokens = generator.generate([model], sample, prefix_tokens=prefix)
        # First hypothesis of the first (only) utterance.
        pred = cls.detokenize(task, pred_tokens[0][0]["tokens"])
        eos_token = task.data_cfg.config.get("eos_token", None)
        if eos_token:
            # Drop the last space-separated token when an EOS token is configured.
            pred = " ".join(pred.split(" ")[:-1])

        if synthesize_speech:
            # Hub keys are prefixed with the language when tags are prepended.
            pfx = f"{_tgt_lang}_" if task.data_cfg.prepend_tgt_lang_tag else ""
            tts_model_id = task.data_cfg.hub.get(f"{pfx}tts_model_id", None)
            speaker = task.data_cfg.hub.get(f"{pfx}speaker", None)
            if tts_model_id is None:
                logger.warning("TTS model configuration not found")
            else:
                # Split on the last colon only, so a hub repo carrying a
                # ":ref" suffix ("owner/repo:ref:model") still unpacks.
                _repo, _id = tts_model_id.rsplit(":", 1)
                tts_model = torch.hub.load(_repo, _id, verbose=False)
                pred = (pred, tts_model.predict(pred, speaker=speaker))
        return pred

    def predict(
        self,
        audio: Union[str, torch.Tensor],
        tgt_lang: Optional[str] = None,
        synthesize_speech: bool = False,
    ) -> Union[str, Tuple[str, Tuple[torch.Tensor, int]]]:
        """Run the wrapped model on ``audio``.

        Args:
            audio: either a file path or a ``1 x T`` tensor.
            tgt_lang: optional target language override.
            synthesize_speech: also return synthetic speech for the text.

        Returns:
            Either the text, or ``(text, synthetic speech)``.
        """
        sample = self.get_model_input(self.task, audio)
        return self.get_prediction(
            self.task,
            self.model,
            self.generator,
            sample,
            tgt_lang=tgt_lang,
            synthesize_speech=synthesize_speech,
        )