| from typing import List, Tuple
|
|
|
| import numpy as np
|
| import librosa
|
| import torch
|
| import torch.nn.functional as F
|
| from s3tokenizer.utils import padding
|
| from s3tokenizer.model_v2 import (
|
| S3TokenizerV2,
|
| ModelConfig,
|
| )
|
|
|
|
|
|
|
# Audio / token-rate constants used by the S3 speech tokenizer.

# Sampling rate (Hz) of the waveforms the tokenizer consumes.
S3_SR = 16_000

# Number of speech tokens emitted per second of audio.
S3_TOKEN_RATE = 25

# STFT hop in samples used for the mel spectrogram (100 frames/sec at S3_SR).
S3_HOP = 160

# Waveform samples covered by a single speech token (40 ms at S3_SR).
S3_TOKEN_HOP = S3_SR // S3_TOKEN_RATE

# Size of the speech-token vocabulary.
SPEECH_VOCAB_SIZE = 6561
|
|
|
|
|
class S3Tokenizer(S3TokenizerV2):
    """
    s3tokenizer.S3TokenizerV2 with the following changes:
    - a more integrated `forward` that goes from raw waveforms to speech tokens
    - `log_mel_spectrogram` is computed from the `_mel_filters` and `window`
      tensors that are registered as module buffers in `__init__`
    """

    # Names of the buffers registered in `__init__`.
    # NOTE(review): presumably read by checkpoint-loading code elsewhere so that
    # these keys may be absent from a state dict -- confirm against the loader.
    ignore_state_dict_missing = ("_mel_filters", "window")

    def __init__(
        self,
        name: str = "speech_tokenizer_v2_25hz",
        config: ModelConfig = ModelConfig(),
    ):
        """
        Build the tokenizer and register the mel filterbank / STFT window buffers.

        Args
        ----
        - `name`: forwarded to `S3TokenizerV2.__init__`.
        - `config`: only `config.n_mels` is read here, to size the mel filterbank.
          NOTE(review): `config` is not forwarded to the base class -- confirm
          that the base class' own configuration uses the same `n_mels`.
        """
        super().__init__(name)

        self.n_fft = 400
        mel_basis = librosa.filters.mel(
            sr=S3_SR,
            n_fft=self.n_fft,
            n_mels=config.n_mels,
        )
        # Buffers follow the module across `.to(device)` calls.
        self.register_buffer("_mel_filters", torch.FloatTensor(mel_basis))
        self.register_buffer("window", torch.hann_window(self.n_fft))

    def pad(self, wavs, sr) -> List[torch.Tensor]:
        """
        Given a list of wavs with the same sample rate `sr`, zero-pad each one
        on the right so that its length is a multiple of 40ms (S3 runs at
        25 token/sec).

        Args
        ----
        - `wavs`: iterable of waveforms (`np.ndarray` or `torch.Tensor`); 1-D
          inputs get a leading dimension added.
        - `sr`: sample rate shared by all `wavs`.

        Returns
        -------
        List of padded tensors, in the same order as `wavs`.
        """
        padded_wavs = []
        for wav in self._prepare_audio(wavs):
            n_samples = wav.shape[-1]
            # Multiply before dividing: `(n_samples / sr) * S3_TOKEN_RATE` can land
            # just above an exact integer (e.g. 4480 / 16000 * 25 ==
            # 7.000000000000001), which made `ceil` add a spurious extra token.
            n_tokens = int(np.ceil(n_samples * S3_TOKEN_RATE / sr))
            target_len = int(n_tokens * sr / S3_TOKEN_RATE)
            padded_wavs.append(
                F.pad(
                    wav,
                    (0, target_len - n_samples),
                    mode="constant",
                    value=0,
                )
            )
        return padded_wavs

    def _prepare_audio(self, wavs):
        """Prepare a list of audios for s3tokenizer processing.

        NumPy arrays are converted to tensors and 1-D waveforms get a leading
        dimension added; everything else is passed through unchanged.
        """
        prepared = []
        for wav in wavs:
            if isinstance(wav, np.ndarray):
                wav = torch.from_numpy(wav)
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)
            prepared.append(wav)
        return prepared

    @torch.no_grad()
    def forward(
        self,
        wavs: torch.Tensor,
        accelerator: 'Accelerator' = None,
        max_len: int = None,
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """
        Tokenize a batch of waveforms into speech tokens.

        NOTE: mel-spec has a hop size of 160 points (100 frame/sec).
        FIXME: this class inherits `nn.Module` but doesn't accept `torch.Tensor`
        and handles a list of wavs one by one, which is unexpected.

        Args
        ----
        - `wavs`: 16 kHz speech audio; iterated item by item, each item being a
          `np.ndarray` or `torch.Tensor` waveform.
        - `accelerator`: optional object exposing `unwrap_model`; when given,
          quantization runs on `accelerator.unwrap_model(self)`.
        - `max_len`: max length to truncate the output sequence to (25 token/sec).
          NOTE: please pad the waveform if longer sequence is needed.

        Returns
        -------
        `(speech_tokens, speech_token_lens)`, both detached and cast to long.
        """
        # Each token spans S3_TOKEN_HOP samples and each mel frame S3_HOP samples.
        frames_per_token = S3_TOKEN_HOP // S3_HOP

        # NOTE(review): `self.device` is not defined in this class; presumably it
        # is provided by the base class -- confirm.
        mels = []
        for wav in self._prepare_audio(wavs):
            mel = self.log_mel_spectrogram(wav.to(self.device))
            if max_len is not None:
                mel = mel[..., :max_len * frames_per_token]
            mels.append(mel.squeeze(0))

        # `padding` is the helper imported from `s3tokenizer.utils`.
        mels, mel_lens = padding(mels)

        tokenizer = self if accelerator is None else accelerator.unwrap_model(self)
        speech_tokens, speech_token_lens = tokenizer.quantize(
            mels, mel_lens.to(self.device)
        )
        return (
            speech_tokens.long().detach(),
            speech_token_lens.long().detach(),
        )

    def log_mel_spectrogram(
        self,
        audio: torch.Tensor,
        padding: int = 0,
    ) -> torch.Tensor:
        """
        Compute the normalized log-Mel spectrogram of `audio`.

        Parameters
        ----------
        audio: torch.Tensor, shape = (*)
            A Tensor (or NumPy array, which is converted) containing the
            audio waveform in 16 kHz

        padding: int
            Number of zero samples to pad to the right.
            NOTE: inside this method the name shadows the module-level
            `padding` helper; the helper is not used here.

        Returns
        -------
        torch.Tensor
            A Tensor that contains the Mel spectrogram; the last two dimensions
            are (mel bins, frames), where the number of mel bins is the row
            count of `self._mel_filters`.
        """
        if not torch.is_tensor(audio):
            audio = torch.from_numpy(audio)
        audio = audio.to(self.device)

        if padding > 0:
            audio = F.pad(audio, (0, padding))

        stft = torch.stft(
            audio,
            self.n_fft,
            S3_HOP,
            window=self.window.to(self.device),
            return_complex=True,
        )
        # Power spectrum; the trailing STFT frame is dropped.
        magnitudes = stft[..., :-1].abs() ** 2

        mel_spec = self._mel_filters.to(self.device) @ magnitudes

        # log10 with a floor, limited to 8 log10-units below the global maximum,
        # then shifted and scaled by (x + 4) / 4.
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec
|
|
|