| from typing import List, Tuple |
|
|
| import numpy as np |
| import librosa |
| import torch |
| import torch.nn.functional as F |
| from s3tokenizer.utils import padding |
| from s3tokenizer.model_v2 import ( |
| S3TokenizerV2, |
| ModelConfig, |
| ) |
|
|
|
|
| |
# Audio sample rate expected by the tokenizer front-end (Hz).
S3_SR = 16_000
# Mel-spectrogram hop: 160 samples = 10 ms, i.e. 100 mel frames per second.
S3_HOP = 160
# Token hop: 4 mel hops = 640 samples = 40 ms of audio per speech token.
S3_TOKEN_HOP = 4 * S3_HOP
# Speech tokens per second (16000 // 640 == 25).
S3_TOKEN_RATE = S3_SR // S3_TOKEN_HOP
# Size of the speech-token vocabulary (3 ** 8 == 6561).
SPEECH_VOCAB_SIZE = 3 ** 8
|
|
|
|
class S3Tokenizer(S3TokenizerV2):
    """
    s3tokenizer.S3TokenizerV2 with the following changes:
    - a more integrated `forward`
    - compute `log_mel_spectrogram` using `_mel_filters` and `window`,
      both registered via `register_buffer`
    """

    # Names of the buffers created in `__init__`. NOTE(review): presumably
    # consulted by checkpoint-loading code so that state dicts lacking these
    # (recomputable) buffers are accepted -- confirm against the loader.
    ignore_state_dict_missing = ("_mel_filters", "window")

    def __init__(
        self,
        name: str = "speech_tokenizer_v2_25hz",
        config: "ModelConfig | None" = None,
    ):
        """
        Build the tokenizer and its mel front-end buffers.

        Args
        ----
        - `name`: model name, forwarded to `S3TokenizerV2`.
        - `config`: model config. A fresh `ModelConfig()` is created when
          omitted. It is forwarded to `S3TokenizerV2`, and `config.n_mels`
          sets the number of bands of the mel filterbank.
        """
        if config is None:
            # Build per instance instead of sharing one default object that
            # was evaluated once at function-definition time.
            config = ModelConfig()
        # Forward `config` so the parent model and the mel filterbank below
        # are built from the same settings (e.g. `n_mels`).
        super().__init__(name, config)

        self.n_fft = 400
        mel_filters = librosa.filters.mel(
            sr=S3_SR,
            n_fft=self.n_fft,
            n_mels=config.n_mels,
        )
        # Shape (n_mels, n_fft // 2 + 1); applied as a matrix product in
        # `log_mel_spectrogram`.
        self.register_buffer(
            "_mel_filters",
            torch.tensor(mel_filters, dtype=torch.float32),
        )
        self.register_buffer(
            "window",
            torch.hann_window(self.n_fft),
        )

    def pad(self, wavs, sr) -> List[torch.Tensor]:
        """
        Right-pad each wav with zeros so its length is a whole number of
        tokens (S3 runs at 25 token/sec, i.e. a multiple of 40 ms).

        Args
        ----
        - `wavs`: iterable of waveforms (`np.ndarray` or `torch.Tensor`), all
          sampled at `sr`. 1-D inputs are promoted to shape (1, T).
        - `sr`: sample rate of `wavs` in Hz.

        Returns
        -------
        List of tensors whose last dimension is zero-padded up to the next
        token boundary. Inputs already on a boundary are left unchanged.
        """
        processed_wavs = []
        for wav in self._prepare_audio(wavs):
            n_samples = wav.shape[-1]
            # Exact integer ceil-division. The float form
            # `ceil(n / sr * rate)` overshoots on exact multiples
            # (4480 samples @ 16 kHz -> 7.000000000000001 -> 8 tokens).
            n_tokens = -(-n_samples * S3_TOKEN_RATE // sr)
            intended_wav_len = int(n_tokens * sr // S3_TOKEN_RATE)
            wav = F.pad(
                wav,
                (0, intended_wav_len - n_samples),
                mode="constant",
                value=0,
            )
            processed_wavs.append(wav)
        return processed_wavs

    def _prepare_audio(self, wavs):
        """
        Prepare a list of audios for s3tokenizer processing.

        NumPy arrays are converted with `torch.from_numpy` and 1-D waveforms
        gain a leading dimension of size 1. Returns a new list of tensors.
        """
        processed_wavs = []
        for wav in wavs:
            if isinstance(wav, np.ndarray):
                wav = torch.from_numpy(wav)
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)
            processed_wavs.append(wav)
        return processed_wavs

    @torch.no_grad()
    def forward(
        self,
        wavs: torch.Tensor,
        accelerator: "Accelerator" = None,
        max_len: int = None,
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """
        Tokenize a batch of waveforms into speech tokens.

        NOTE: mel-spec has a hop size of 160 points (100 frame/sec), and each
        token spans S3_TOKEN_HOP / S3_HOP = 4 mel frames.
        FIXME: this class inherits `nn.Module` but doesn't accept
        `torch.Tensor` and handles a list of wavs one by one, which is
        unexpected.

        Args
        ----
        - `wavs`: iterable of 16 kHz speech waveforms (`np.ndarray` or
          `torch.Tensor`, 1-D or shaped (1, T)).
        - `accelerator`: optional object whose `unwrap_model(self)` returns
          the module to call `quantize` on (presumably a HuggingFace
          `accelerate.Accelerator` -- confirm with callers).
        - `max_len`: max length to truncate the output sequence to
          (25 token/sec).
          NOTE: please pad the waveform if longer sequence is needed.

        Returns
        -------
        `(speech_tokens, speech_token_lens)` as int64 tensors, detached from
        the graph. NOTE(review): presumably shaped (B, T_tok) and (B,).
        """
        mel_frames_per_token = S3_TOKEN_HOP // S3_HOP
        mels = []
        for wav in self._prepare_audio(wavs):
            wav = wav.to(self.device)
            mel = self.log_mel_spectrogram(wav)
            if max_len is not None:
                mel = mel[..., :max_len * mel_frames_per_token]
            # Drop the leading size-1 dim so `padding` receives
            # (n_mels, n_frames) tensors.
            mels.append(mel.squeeze(0))

        mels, mel_lens = padding(mels)
        tokenizer = self if accelerator is None else accelerator.unwrap_model(self)

        speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device))
        return (
            speech_tokens.long().detach(),
            speech_token_lens.long().detach(),
        )

    def log_mel_spectrogram(
        self,
        audio: torch.Tensor,
        padding: int = 0,
    ):
        """
        Compute the log-Mel spectrogram of `audio`.

        Parameters
        ----------
        audio: torch.Tensor or np.ndarray, shape = (T,) or (B, T)
            Audio waveform sampled at 16 kHz. NumPy arrays are converted
            with `torch.from_numpy`.

        padding: int
            Number of zero samples to pad to the right

        Returns
        -------
        torch.Tensor, shape = (n_mels, n_frames) or (B, n_mels, n_frames)
            log10 Mel power spectrogram, floored at (global max - 8) and
            rescaled as `(log_spec + 4) / 4`.
        """
        if not torch.is_tensor(audio):
            audio = torch.from_numpy(audio)

        audio = audio.to(self.device)
        if padding > 0:
            audio = F.pad(audio, (0, padding))
        stft = torch.stft(
            audio, self.n_fft, S3_HOP,
            window=self.window.to(self.device),
            return_complex=True
        )
        # Power spectrum; the final STFT frame is dropped.
        magnitudes = stft[..., :-1].abs() ** 2

        # (n_mels, n_freq) @ (..., n_freq, n_frames) -> (..., n_mels, n_frames)
        mel_spec = self._mel_filters.to(self.device) @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        # Limit dynamic range to 8 decades below the max. The max is taken
        # over the whole tensor, i.e. across all batch rows.
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec
|
|