| from attr import attr |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class CustomSTFT(nn.Module):
    """
    STFT/iSTFT without unfold/complex ops, using conv1d and conv_transpose1d.

    - forward STFT => real-part conv1d + imag-part conv1d
    - inverse STFT => real-part conv_transpose1d + imag-part conv_transpose1d,
      combined as Re(X) * cos - Im(X) * sin
    - avoids F.unfold, so it is easier to export to ONNX
    - uses replicate or constant padding for ``center=True`` to approximate
      ``reflect`` (reflect is not supported for dynamic shapes in ONNX)

    Note: ``inverse`` scales every one-sided frequency bin by ``1 / n_fft``
    (bins 1 .. n_fft/2 - 1 are not doubled) and applies no window-sum
    normalization, so it is not an exact inverse of ``transform`` and does
    not numerically match ``torch.istft``.
    """

    def __init__(
        self,
        filter_length=800,
        hop_length=200,
        win_length=800,
        window="hann",
        center=True,
        pad_mode="replicate",
    ):
        """
        Precompute the windowed DFT / inverse-DFT kernels as buffers.

        Args:
            filter_length: FFT size (``n_fft``); also the conv kernel length.
            hop_length: stride between successive frames, in samples.
            win_length: Hann window length. A shorter window is zero-padded
                on both sides to ``n_fft`` (as ``torch.stft`` does); a longer
                one is truncated to ``n_fft``.
            window: window type; only ``"hann"`` is supported.
            center: if True, ``transform`` pads the input by ``n_fft // 2``
                on both sides and ``inverse`` strips that padding again.
            pad_mode: ``F.pad`` mode used for the centering pad.

        Raises:
            ValueError: if ``window`` is not ``"hann"``.
        """
        super().__init__()
        if window != "hann":
            # Explicit raise rather than assert so the check survives `python -O`.
            raise ValueError(f"only window='hann' is supported, got {window!r}")

        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_fft = filter_length
        self.center = center
        self.pad_mode = pad_mode
        # One-sided spectrum size: DC .. Nyquist.
        self.freq_bins = self.n_fft // 2 + 1

        window_tensor = torch.hann_window(win_length, periodic=True, dtype=torch.float32)
        if self.win_length < self.n_fft:
            # Center the shorter window inside the n_fft frame, matching torch.stft.
            extra = self.n_fft - self.win_length
            left = extra // 2
            window_tensor = F.pad(window_tensor, (left, extra - left))
        elif self.win_length > self.n_fft:
            window_tensor = window_tensor[: self.n_fft]
        self.register_buffer("window", window_tensor)

        # DFT basis angles, shape (freq_bins, n_fft): angle[k, n] = 2*pi*k*n / n_fft.
        n = np.arange(self.n_fft)
        k = np.arange(self.freq_bins)
        angle = 2 * np.pi * np.outer(k, n) / self.n_fft
        cos_kn = np.cos(angle)
        sin_kn = np.sin(angle)
        window_np = window_tensor.numpy()

        # Forward kernels: windowed e^{-i*angle} => real = cos, imag = -sin.
        # unsqueeze(1) yields conv1d weights of shape (freq_bins, 1, n_fft).
        self.register_buffer(
            "weight_forward_real",
            torch.from_numpy(cos_kn * window_np).float().unsqueeze(1),
        )
        self.register_buffer(
            "weight_forward_imag",
            torch.from_numpy(-sin_kn * window_np).float().unsqueeze(1),
        )

        # Inverse kernels: windowed e^{+i*angle} scaled by 1 / n_fft, used as
        # conv_transpose1d weights of shape (freq_bins, 1, n_fft) so the strided
        # transposed convolution performs the overlap-add.
        inv_window = window_np * (1.0 / self.n_fft)
        self.register_buffer(
            "weight_backward_real",
            torch.from_numpy(cos_kn * inv_window).float().unsqueeze(1),
        )
        self.register_buffer(
            "weight_backward_imag",
            torch.from_numpy(sin_kn * inv_window).float().unsqueeze(1),
        )

    def transform(self, waveform: torch.Tensor):
        """
        Forward STFT.

        Args:
            waveform: signal of shape (batch, time).

        Returns:
            ``(magnitude, phase)``, each of shape (batch, freq_bins, frames);
            phase is in radians.
        """
        if self.center:
            pad_len = self.n_fft // 2
            waveform = F.pad(waveform, (pad_len, pad_len), mode=self.pad_mode)

        # (batch, 1, time): one input channel; each output channel is one bin.
        x = waveform.unsqueeze(1)
        real_out = F.conv1d(
            x,
            self.weight_forward_real,
            bias=None,
            stride=self.hop_length,
            padding=0,
        )
        imag_out = F.conv1d(
            x,
            self.weight_forward_imag,
            bias=None,
            stride=self.hop_length,
            padding=0,
        )

        # The epsilon keeps sqrt away from 0, where its gradient is infinite.
        magnitude = torch.sqrt(real_out**2 + imag_out**2 + 1e-14)
        phase = torch.atan2(imag_out, real_out)
        # On the negative real axis atan2 yields -pi for a negative-zero
        # imaginary part; pin that branch to +pi (presumably to make PyTorch
        # and ONNX runtimes agree -- confirm). torch.where avoids an in-place
        # boolean-mask write and exports as a single Where op.
        correction_mask = (imag_out == 0) & (real_out < 0)
        phase = torch.where(correction_mask, torch.full_like(phase, torch.pi), phase)
        return magnitude, phase

    def inverse(self, magnitude: torch.Tensor, phase: torch.Tensor, length=None):
        """
        Inverse STFT from magnitude and phase via transposed convolution.

        Args:
            magnitude: spectrum magnitudes of shape (batch, freq_bins, frames).
            phase: phase in radians, same shape as ``magnitude``.
            length: optional output length. The reconstruction is trimmed to
                it, or right-padded with zeros when shorter (as ``torch.istft``
                does).

        Returns:
            Waveform of shape (batch, 1, time).
        """
        real_part = magnitude * torch.cos(phase)
        imag_part = magnitude * torch.sin(phase)

        # freq_bins act as input channels, frames as the time axis; stride
        # hop_length makes the transposed conv overlap-add successive frames.
        real_rec = F.conv_transpose1d(
            real_part,
            self.weight_backward_real,
            bias=None,
            stride=self.hop_length,
            padding=0,
        )
        imag_rec = F.conv_transpose1d(
            imag_part,
            self.weight_backward_imag,
            bias=None,
            stride=self.hop_length,
            padding=0,
        )
        # Re(X * e^{i*angle}) = Re(X) * cos(angle) - Im(X) * sin(angle).
        waveform = real_rec - imag_rec

        if self.center:
            pad_len = self.n_fft // 2
            # Guard: a [0:-0] slice would return an empty tensor.
            if pad_len > 0:
                waveform = waveform[..., pad_len:-pad_len]

        if length is not None:
            # With center=True the reconstruction holds (T // hop) * hop samples,
            # which can be shorter than the requested length; zero-pad the tail.
            missing = length - waveform.shape[-1]
            if missing > 0:
                waveform = F.pad(waveform, (0, missing))
            waveform = waveform[..., :length]

        return waveform

    def forward(self, x: torch.Tensor):
        """
        Round trip: STFT of ``x`` followed by the inverse STFT.

        Args:
            x: signal of shape (batch, time).

        Returns:
            Tensor of shape (batch, 1, time) with exactly ``x.shape[-1]``
            samples. See the class note: the inverse is not an exact iSTFT.
        """
        mag, phase = self.transform(x)
        return self.inverse(mag, phase, length=x.shape[-1])
|
|