from typing import List, Optional, Tuple

import librosa
import numpy as np
import torch
import torch.nn.functional as F

from s3tokenizer.model_v2 import (
    ModelConfig,
    S3TokenizerV2,
)
from s3tokenizer.utils import padding
| |
|
| |
|
| | |
# Audio / token framing constants for the S3 speech tokenizer.
S3_SR = 16_000          # expected input sample rate, Hz
S3_HOP = 160            # mel hop size in samples -> 100 mel frames / sec
S3_TOKEN_HOP = 640      # samples per speech token -> 25 tokens / sec
S3_TOKEN_RATE = 25      # speech tokens per second
SPEECH_VOCAB_SIZE = 6561  # size of the speech-token vocabulary
| |
|
| |
|
class S3Tokenizer(S3TokenizerV2):
    """
    s3tokenizer.S3TokenizerV2 with the following changes:

    - a more integrated `forward`
    - compute `log_mel_spectrogram` using `_mel_filters` and `window`
      registered as module buffers (so they follow `.to(device)`)
    """

    # These buffers are created locally in __init__ and are not part of the
    # upstream checkpoint, so state-dict loading must tolerate them missing.
    ignore_state_dict_missing = ("_mel_filters", "window")

    def __init__(
        self,
        name: str = "speech_tokenizer_v2_25hz",
        config: Optional[ModelConfig] = None,
    ):
        """
        Args
        ----
        - `name`: tokenizer variant, forwarded to `S3TokenizerV2.__init__`.
        - `config`: model configuration; a fresh `ModelConfig()` is created
          when omitted.

        NOTE: the original signature used `config: ModelConfig = ModelConfig()`,
        which is evaluated once at definition time and shared by every
        instance (mutable-default pitfall); the sentinel keeps the same
        behavior for callers while creating the default per instance.
        """
        super().__init__(name)

        if config is None:
            config = ModelConfig()

        self.n_fft = 400
        # Mel filterbank for 16 kHz input with a 400-point FFT window;
        # number of mel bands comes from the model config.
        _mel_filters = librosa.filters.mel(
            sr=S3_SR,
            n_fft=self.n_fft,
            n_mels=config.n_mels,
        )
        self.register_buffer(
            "_mel_filters",
            torch.FloatTensor(_mel_filters),
        )

        self.register_buffer(
            "window",
            torch.hann_window(self.n_fft),
        )

    def pad(self, wavs, sr) -> List[torch.Tensor]:
        """
        Given a list of wavs with the same `sample_rate`, pad them so that the
        length is a multiple of 40 ms (S3 runs at 25 token/sec).

        Args
        ----
        - `wavs`: list of 1-D or (1, T) waveforms (`np.ndarray` or tensor).
        - `sr`: sample rate shared by all waveforms in `wavs`.

        Returns a list of (1, T') tensors, zero-padded on the right.
        """
        processed_wavs = []
        for wav in wavs:
            if isinstance(wav, np.ndarray):
                wav = torch.from_numpy(wav)
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)

            # Round the duration up to a whole number of S3 tokens, then pad
            # the waveform to exactly that many tokens' worth of samples.
            n_tokens = (wav.shape[1] / sr) * S3_TOKEN_RATE
            n_tokens = np.ceil(n_tokens)
            intended_wav_len = n_tokens * (sr / S3_TOKEN_RATE)
            intended_wav_len = int(intended_wav_len)
            wav = F.pad(
                wav,
                (0, intended_wav_len - wav.shape[-1]),
                mode="constant",
                value=0,
            )
            processed_wavs.append(wav)
        return processed_wavs

    def _prepare_audio(self, wavs):
        """Prepare a list of audios for s3tokenizer processing.

        Converts NumPy arrays to tensors and ensures each waveform has a
        leading batch dimension of 1. No resampling or padding is done here.
        """
        processed_wavs = []
        for wav in wavs:
            if isinstance(wav, np.ndarray):
                wav = torch.from_numpy(wav)
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)

            processed_wavs.append(wav)
        return processed_wavs

    @torch.no_grad()
    def forward(
        self,
        wavs: torch.Tensor,
        accelerator: 'Accelerator' = None,
        max_len: int = None,
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """
        NOTE: mel-spec has a hop size of 160 points (100 frame/sec).
        FIXME: this class inherits `nn.Module` but doesn't accept `torch.Tensor` and handles a list of wavs one by one, which is unexpected.

        Args
        ----
        - `wavs`: 16 kHz speech audio
        - `accelerator`: optional wrapper (e.g. HF Accelerate); when given,
          the model is unwrapped before calling `quantize`.
        - `max_len` max length to truncate the output sequence to (25 token/sec).
        NOTE: please pad the waveform if longer sequence is needed.

        Returns
        -------
        `(speech_tokens, speech_token_lens)`, both detached LongTensors.
        """
        processed_wavs = self._prepare_audio(wavs)
        mels = []  # (fix: original also initialized an unused `mel_lens` list)
        for wav in processed_wavs:
            wav = wav.to(self.device)
            mel = self.log_mel_spectrogram(wav)  # [1, F, T]
            if max_len is not None:
                # 4 mel frames (100 f/s) per speech token (25 tok/s).
                mel = mel[..., :max_len * 4]
            mels.append(mel.squeeze(0))

        mels, mel_lens = padding(mels)
        # Unwrap any accelerate/DDP wrapper so `.quantize` is reachable.
        if accelerator is None:
            tokenizer = self
        else:
            tokenizer = accelerator.unwrap_model(self)

        speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device))
        return (
            speech_tokens.long().detach(),
            speech_token_lens.long().detach(),
        )

    def log_mel_spectrogram(
        self,
        audio: torch.Tensor,
        padding: int = 0,
    ):
        """
        Compute the log-Mel spectrogram of a 16 kHz waveform.

        NOTE(review): the `padding` parameter shadows the imported `padding`
        helper inside this method; the name is kept for signature
        compatibility with existing callers.

        Parameters
        ----------
        audio: torch.Tensor, shape = (*)
            A NumPy array or Tensor containing the audio waveform in 16 kHz

        padding: int
            Number of zero samples to pad to the right

        Returns
        -------
        torch.Tensor, shape = (n_mels, n_frames)
            A Tensor that contains the log-Mel spectrogram.
            (Fix: the original docstring hard-coded 128 mel bands, but the
            filterbank is built with `config.n_mels`.)
        """
        if not torch.is_tensor(audio):
            audio = torch.from_numpy(audio)

        audio = audio.to(self.device)
        if padding > 0:
            audio = F.pad(audio, (0, padding))
        stft = torch.stft(
            audio, self.n_fft, S3_HOP,
            window=self.window.to(self.device),
            return_complex=True
        )
        # Drop the final STFT frame and take the power spectrum
        # (Whisper-style feature extraction).
        magnitudes = stft[..., :-1].abs()**2

        mel_spec = self._mel_filters.to(self.device) @ magnitudes

        # Log compression, floored 8 dB below the peak, then rescaled to
        # roughly [-1, 1] (Whisper-style normalization).
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec
| |
|