| import librosa |
| import torch |
| from typing import Tuple |
|
|
| from espnet.nets.pytorch_backend.nets_utils import make_pad_mask |
|
|
|
|
class LogMel(torch.nn.Module):
    """Convert STFT features to log-Mel filterbank (fbank) features.

    The filterbank arguments are the same as those of ``librosa.filters.mel``.

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz).
            If `None`, use `fmin = 0`
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        log_base: base of the logarithm applied to the Mel energies.
            If `None`, the natural logarithm is used.

    Raises:
        ValueError: if `log_base` is given but is not a valid logarithm base
            (it must be positive and different from 1).
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = None,
        fmax: float = None,
        htk: bool = False,
        log_base: float = None,
    ):
        super().__init__()

        # Fail early: such a base would otherwise produce NaN/inf (or a
        # division by zero) for every frame in forward().
        if log_base is not None and (log_base <= 0 or log_base == 1):
            raise ValueError(
                f"log_base must be positive and not equal to 1, got {log_base}"
            )

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        _mel_options = dict(
            sr=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        # Kept so that extra_repr() can show the filterbank configuration.
        self.mel_options = _mel_options
        self.log_base = log_base

        # librosa returns the filterbank as (n_mels, 1 + n_fft // 2).
        melmat = librosa.filters.mel(**_mel_options)
        # Store the transpose as a non-trainable buffer so that forward()
        # can right-multiply the features by it: (D1, n_mels).
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

    def extra_repr(self):
        """Return the Mel filterbank options as ``key=value`` pairs."""
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self,
        feat: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the Mel filterbank and take the logarithm.

        Args:
            feat: input features. Dim 0 is used as the batch size and dim 1
                as the sequence length; the last dim is multiplied with the
                filterbank and therefore must equal ``1 + n_fft // 2``.
            ilens: length of each sequence in the batch. If `None`, every
                sequence is assumed to span the full ``feat.size(1)``.

        Returns:
            A tuple ``(logmel_feat, ilens)``: the log-Mel features with the
            last dim replaced by ``n_mels``, and the sequence lengths (newly
            created as a long tensor when `ilens` was `None`).
        """
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)
        # Floor the energies so that the logarithm never sees zero.
        mel_feat = torch.clamp(mel_feat, min=1e-10)

        if self.log_base is None:
            logmel_feat = mel_feat.log()
        elif self.log_base == 2.0:
            logmel_feat = mel_feat.log2()
        elif self.log_base == 10.0:
            logmel_feat = mel_feat.log10()
        else:
            # Change of base: log_b(x) = ln(x) / ln(b). torch.log() only
            # accepts tensors, so wrap the Python scalar in a tensor that
            # matches the dtype and device of the features.
            log_base = mel_feat.new_tensor(self.log_base)
            logmel_feat = mel_feat.log() / torch.log(log_base)

        if ilens is not None:
            # Zero out the padded part; the mask is built along dim 1
            # (presumably True at padded frames - see make_pad_mask).
            logmel_feat = logmel_feat.masked_fill(
                make_pad_mask(ilens, logmel_feat, 1), 0.0
            )
        else:
            # No lengths given: treat every sequence as full length.
            ilens = feat.new_full(
                [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
            )
        return logmel_feat, ilens
|
|