STAR / fairseq /models /speech_to_text /hub_interface.py
Yixuan Li
add fairseq folder
85ba398
raw
history blame
5.01 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from argparse import Namespace
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import fairseq.data.audio.feature_transforms.utterance_cmvn as utt_cmvn
from fairseq.data import encoders
from fairseq.data.audio.audio_utils import convert_waveform as convert_wav
from fairseq.data.audio.audio_utils import get_fbank
from fairseq.data.audio.audio_utils import get_waveform as get_wav
from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDataset
# Module-level logger; used below to warn when no TTS model is configured.
logger = logging.getLogger(__name__)


class S2THubInterface(nn.Module):
    """Inference wrapper around a fairseq speech-to-text model for hub use.

    Holds the config, task and model, switches the model to eval mode and
    builds a generator from ``cfg.generation`` so that :meth:`predict` can go
    from raw audio to text, optionally followed by speech synthesis with a
    TTS model loaded through ``torch.hub``.
    """

    def __init__(self, cfg, task, model):
        super().__init__()
        self.cfg = cfg
        self.task = task
        self.model = model
        self.model.eval()
        self.generator = self.task.build_generator([self.model], self.cfg.generation)

    @classmethod
    def get_model_input(cls, task, audio: Union[str, torch.Tensor]):
        """Build a single-utterance sample from an audio file path or tensor.

        The feature type is read from ``task.data_cfg.hub["input_type"]``:

        * ``"fbank80_w_utt_cmvn"``: 80-bin filterbank features with
          utterance-level CMVN, returned as ``1 x T x D``.
        * ``"waveform"`` / ``"standardized_waveform"``: raw samples returned
          as ``1 x T``. A file path is converted to 16 kHz mono; the
          standardized variant is additionally layer-normalized over the
          whole utterance.

        Args:
            task: task whose ``data_cfg.hub`` dict selects ``input_type``.
            audio: path to an audio file, or a waveform tensor. A 1-D
                ``(T,)`` tensor is treated as one channel, ``(1, T)``.
                Tensor input is not resampled or downmixed here.

        Returns:
            dict with ``net_input`` (``src_tokens``, ``src_lengths``,
            ``prev_output_tokens=None``) plus ``target_lengths`` and
            ``speaker`` set to ``None``.

        Raises:
            ValueError: if ``input_type`` is not one of the values above.
        """
        input_type = task.data_cfg.hub.get("input_type", "fbank80")
        if isinstance(audio, torch.Tensor):
            # detach()/cpu() so the .numpy() calls below also work for
            # grad-tracking or GPU tensors; add a channel dim to bare (T,).
            audio = audio.detach().cpu()
            if audio.dim() == 1:
                audio = audio.unsqueeze(0)

        if input_type == "fbank80_w_utt_cmvn":
            if isinstance(audio, str):
                feat = get_fbank(audio)  # T x D (numpy)
            else:
                import torchaudio.compliance.kaldi as kaldi

                feat = kaldi.fbank(audio, num_mel_bins=80).numpy()  # T x D
            # Utterance CMVN is part of this input type, so apply it for both
            # file and tensor input.
            feat = utt_cmvn.UtteranceCMVN()(feat)
            # feat is a numpy array (no .unsqueeze); index with None instead.
            feat = feat[None]  # T x D -> 1 x T x D
        elif input_type in {"waveform", "standardized_waveform"}:
            if isinstance(audio, str):
                feat, sr = get_wav(audio)  # C x T
                feat, _ = convert_wav(
                    feat, sr, to_sample_rate=16_000, to_mono=True
                )  # C x T -> 1 x T
            else:
                feat = audio.numpy()  # C x T, taken as-is
        else:
            raise ValueError(f"Unknown value: input_type = {input_type}")

        src_lengths = torch.tensor([feat.shape[1]], dtype=torch.long)
        src_tokens = torch.from_numpy(feat)  # 1 x T (x D)
        if input_type == "standardized_waveform":
            with torch.no_grad():
                src_tokens = F.layer_norm(src_tokens, src_tokens.shape)
        return {
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "prev_output_tokens": None,
            },
            "target_lengths": None,
            "speaker": None,
        }

    @classmethod
    def detokenize(cls, task, tokens):
        """Convert target token ids to text and undo the BPE tokenizer, if any.

        Args:
            task: task providing ``tgt_dict`` and ``data_cfg.bpe_tokenizer``.
            tokens: token ids of one hypothesis.

        Returns:
            The decoded string; the raw dictionary string when no BPE
            tokenizer is configured.
        """
        text = task.tgt_dict.string(tokens)
        # A missing/None tokenizer config means "no BPE" rather than a crash
        # in Namespace(**None).
        tkn_cfg = task.data_cfg.bpe_tokenizer or {"bpe": None}
        tokenizer = encoders.build_bpe(Namespace(**tkn_cfg))
        return text if tokenizer is None else tokenizer.decode(text)

    @classmethod
    def get_prefix_token(cls, task, lang):
        """Return the target-language tag as a ``1 x 1`` prefix, or ``None``.

        Args:
            task: task providing ``data_cfg.prepend_tgt_lang_tag`` and
                ``tgt_dict``.
            lang: target language code; required when language tags are
                prepended.

        Returns:
            ``None`` when the config does not prepend a language tag,
            otherwise a ``LongTensor`` of shape ``(1, 1)``.

        Raises:
            ValueError: if a language tag is required but ``lang`` is None.
        """
        prefix_size = int(task.data_cfg.prepend_tgt_lang_tag)
        if prefix_size <= 0:
            return None
        if lang is None:
            raise ValueError(
                "A target language is required because the data config "
                "prepends target language tags"
            )
        lang_tag = SpeechToTextDataset.get_lang_tag_idx(lang, task.tgt_dict)
        # Build directly as int64 (torch.Tensor would round-trip via float32).
        return torch.tensor([[lang_tag]], dtype=torch.long)

    @classmethod
    def get_prediction(
        cls, task, model, generator, sample, tgt_lang=None, synthesize_speech=False
    ) -> Union[str, Tuple[str, Tuple[torch.Tensor, int]]]:
        """Decode ``sample`` and optionally synthesize speech from the text.

        Args:
            task: task providing dictionaries and data config.
            model: the speech-to-text model.
            generator: sequence generator built for ``model``.
            sample: model input as returned by :meth:`get_model_input`.
            tgt_lang: target language; falls back to the hub config's
                ``tgt_lang``.
            synthesize_speech: if True, also run the configured TTS model.

        Returns:
            The top hypothesis text, or ``(text, tts_output)`` when speech
            synthesis was requested and a TTS model id is configured.
        """
        _tgt_lang = tgt_lang or task.data_cfg.hub.get("tgt_lang", None)
        prefix = cls.get_prefix_token(task, _tgt_lang)
        pred_tokens = generator.generate([model], sample, prefix_tokens=prefix)
        pred = cls.detokenize(task, pred_tokens[0][0]["tokens"])
        eos_token = task.data_cfg.config.get("eos_token", None)
        if eos_token:
            # NOTE(review): presumably the configured EOS token survives
            # detokenization as the final word; the last word is dropped
            # unconditionally here.
            pred = " ".join(pred.split(" ")[:-1])

        if synthesize_speech:
            pfx = f"{_tgt_lang}_" if task.data_cfg.prepend_tgt_lang_tag else ""
            tts_model_id = task.data_cfg.hub.get(f"{pfx}tts_model_id", None)
            speaker = task.data_cfg.hub.get(f"{pfx}speaker", None)
            if tts_model_id is None:
                logger.warning("TTS model configuration not found")
            else:
                # Split on the last colon only so repos that carry a ref
                # ("owner/repo:branch") still parse.
                _repo, _id = tts_model_id.rsplit(":", 1)
                tts_model = torch.hub.load(_repo, _id, verbose=False)
                pred = (pred, tts_model.predict(pred, speaker=speaker))
        return pred

    def predict(
        self,
        audio: Union[str, torch.Tensor],
        tgt_lang: Optional[str] = None,
        synthesize_speech: bool = False,
    ) -> Union[str, Tuple[str, Tuple[torch.Tensor, int]]]:
        """Transcribe/translate ``audio`` with the wrapped model.

        Args:
            audio: file path or waveform tensor (``1 x T`` or ``T``).
            tgt_lang: optional target language override.
            synthesize_speech: also return synthesized speech if True.

        Returns:
            Either the predicted text or ``(text, synthetic speech)``.
        """
        sample = self.get_model_input(self.task, audio)
        return self.get_prediction(
            self.task,
            self.model,
            self.generator,
            sample,
            tgt_lang=tgt_lang,
            synthesize_speech=synthesize_speech,
        )