|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
from argparse import Namespace |
|
|
from typing import Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
import fairseq.data.audio.feature_transforms.utterance_cmvn as utt_cmvn |
|
|
from fairseq.data import encoders |
|
|
from fairseq.data.audio.audio_utils import convert_waveform as convert_wav |
|
|
from fairseq.data.audio.audio_utils import get_fbank |
|
|
from fairseq.data.audio.audio_utils import get_waveform as get_wav |
|
|
from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDataset |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class S2THubInterface(nn.Module):
    """Inference wrapper bundling a speech-to-text task, model and generator.

    The model is put in eval mode at construction time and a sequence
    generator is built from ``cfg.generation``.  ``predict`` turns raw audio
    (a file path or a tensor) into decoded text and, optionally, synthesized
    speech via a TTS model resolved from the data config's ``hub`` section.
    """

    def __init__(self, cfg, task, model):
        """
        Args:
            cfg: config object; ``cfg.generation`` is passed to the task's
                ``build_generator``.
            task: speech-to-text task providing ``data_cfg``, ``tgt_dict`` and
                ``build_generator``.
            model: the model to run; switched to eval mode here.
        """
        super().__init__()
        self.cfg = cfg
        self.task = task
        self.model = model
        self.model.eval()
        self.generator = self.task.build_generator([self.model], self.cfg.generation)

    @classmethod
    def get_model_input(cls, task, audio: Union[str, torch.Tensor]):
        """Build a single-utterance ``sample`` dict for the generator.

        The feature type is read from ``task.data_cfg.hub["input_type"]``
        (default ``"fbank80"``).

        Args:
            task: task whose ``data_cfg.hub`` selects the input type.
            audio: path to an audio file, or an in-memory tensor.  For the
                waveform input types a 1-D tensor is treated as a single
                channel and given a leading batch dimension.

        Returns:
            dict with ``net_input`` (``src_tokens`` with a leading batch dim of
            1, ``src_lengths`` holding the time length, and
            ``prev_output_tokens=None``) plus ``target_lengths`` and
            ``speaker`` set to ``None``.

        Raises:
            ValueError: if ``input_type`` is not a supported value.  Note that
                the default ``"fbank80"`` is not handled here and also raises.
        """
        input_type = task.data_cfg.hub.get("input_type", "fbank80")
        if input_type == "fbank80_w_utt_cmvn":
            if isinstance(audio, str):
                fbank = get_fbank(audio)
            else:
                import torchaudio.compliance.kaldi as kaldi

                # kaldi.fbank expects a (channel, time) waveform; accept 1-D too.
                # NOTE(review): this uses kaldi.fbank's default sample rate and
                # the tensor's own amplitude scale — confirm they match what
                # get_fbank produces for files.
                wav = audio.detach().cpu()
                if wav.dim() == 1:
                    wav = wav.unsqueeze(0)
                fbank = kaldi.fbank(wav, num_mel_bins=80).numpy()  # T x D
            # Apply utterance-level CMVN on both paths so in-memory audio is
            # normalized exactly like file input.
            feat = utt_cmvn.UtteranceCMVN()(fbank)
            # feat is a numpy array (it is fed to torch.from_numpy below), so
            # add the batch dim with numpy indexing: T x D -> 1 x T x D.
            feat = feat[None]
        elif input_type in {"waveform", "standardized_waveform"}:
            if isinstance(audio, str):
                feat, sr = get_wav(audio)
                feat, _ = convert_wav(
                    feat, sr, to_sample_rate=16_000, to_mono=True
                )
            else:
                # detach/cpu so tensors that require grad or live on an
                # accelerator can be converted to numpy.
                feat = audio.detach().cpu().numpy()
            if feat.ndim == 1:
                # Single-channel waveform without a batch dim: T -> 1 x T.
                feat = feat[None]
        else:
            raise ValueError(f"Unknown value: input_type = {input_type}")

        # Dimension 1 is time for both 1 x T (waveform) and 1 x T x D (fbank).
        src_lengths = torch.tensor([feat.shape[1]], dtype=torch.long)
        src_tokens = torch.from_numpy(feat)
        if input_type == "standardized_waveform":
            # Normalize over the whole input to zero mean / unit variance.
            with torch.no_grad():
                src_tokens = F.layer_norm(src_tokens, src_tokens.shape)

        return {
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "prev_output_tokens": None,
            },
            "target_lengths": None,
            "speaker": None,
        }

    @classmethod
    def detokenize(cls, task, tokens):
        """Convert target token ids to text, decoding BPE when configured.

        The BPE tokenizer is rebuilt from ``task.data_cfg.bpe_tokenizer`` on
        every call; if ``encoders.build_bpe`` returns ``None`` the raw
        dictionary string is returned.
        """
        text = task.tgt_dict.string(tokens)
        tkn_cfg = task.data_cfg.bpe_tokenizer
        tokenizer = encoders.build_bpe(Namespace(**tkn_cfg))
        return text if tokenizer is None else tokenizer.decode(text)

    @classmethod
    def get_prefix_token(cls, task, lang):
        """Return the target-language tag as a ``(1, 1)`` long tensor, or None.

        A prefix is only produced when ``task.data_cfg.prepend_tgt_lang_tag``
        is truthy; in that case ``lang`` must be provided.
        """
        prefix_size = int(task.data_cfg.prepend_tgt_lang_tag)
        prefix_tokens = None
        if prefix_size > 0:
            assert lang is not None, (
                "a target language is required when prepend_tgt_lang_tag is set"
            )
            lang_tag = SpeechToTextDataset.get_lang_tag_idx(lang, task.tgt_dict)
            # Build the id directly as int64 (avoids a float32 round-trip).
            prefix_tokens = torch.tensor([[lang_tag]], dtype=torch.long)
        return prefix_tokens

    @classmethod
    def get_prediction(
        cls, task, model, generator, sample, tgt_lang=None, synthesize_speech=False
    ) -> Union[str, Tuple[str, Tuple[torch.Tensor, int]]]:
        """Decode ``sample`` and optionally synthesize speech for the result.

        Args:
            task: task supplying ``data_cfg`` and ``tgt_dict``.
            model: model passed to ``generator.generate``.
            generator: sequence generator.
            sample: dict as produced by ``get_model_input``.
            tgt_lang: target language; falls back to ``data_cfg.hub["tgt_lang"]``.
            synthesize_speech: if True, also run a TTS model from torch.hub.

        Returns:
            The top hypothesis text, or ``(text, tts_output)`` when speech was
            synthesized.  If no TTS model id is configured a warning is logged
            and only the text is returned.
        """
        _tgt_lang = tgt_lang or task.data_cfg.hub.get("tgt_lang", None)
        prefix = cls.get_prefix_token(task, _tgt_lang)
        pred_tokens = generator.generate([model], sample, prefix_tokens=prefix)
        # Best hypothesis of the first (only) utterance.
        pred = cls.detokenize(task, pred_tokens[0][0]["tokens"])
        eos_token = task.data_cfg.config.get("eos_token", None)
        if eos_token:
            # NOTE(review): drops the final whitespace-separated word
            # unconditionally, presumably the emitted eos token — confirm.
            pred = " ".join(pred.split(" ")[:-1])

        if synthesize_speech:
            pfx = f"{_tgt_lang}_" if task.data_cfg.prepend_tgt_lang_tag else ""
            tts_model_id = task.data_cfg.hub.get(f"{pfx}tts_model_id", None)
            speaker = task.data_cfg.hub.get(f"{pfx}speaker", None)
            if tts_model_id is None:
                logger.warning("TTS model configuration not found")
            else:
                # tts_model_id has the form "<repo>:<model>".  The model is
                # loaded via torch.hub on every call (not cached here).
                _repo, _id = tts_model_id.split(":")
                tts_model = torch.hub.load(_repo, _id, verbose=False)
                pred = (pred, tts_model.predict(pred, speaker=speaker))
        return pred

    def predict(
        self,
        audio: Union[str, torch.Tensor],
        tgt_lang: Optional[str] = None,
        synthesize_speech: bool = False,
    ) -> Union[str, Tuple[str, Tuple[torch.Tensor, int]]]:
        """Transcribe/translate ``audio`` with this interface's model.

        Args:
            audio: path to an audio file or an in-memory tensor.
            tgt_lang: optional target language override.
            synthesize_speech: also synthesize speech for the prediction.

        Returns:
            See ``get_prediction``.
        """
        sample = self.get_model_input(self.task, audio)
        return self.get_prediction(
            self.task,
            self.model,
            self.generator,
            sample,
            tgt_lang=tgt_lang,
            synthesize_speech=synthesize_speech,
        )
|
|
|