| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional |
| from conformer import Conformer |
|
|
class NeuralModel(nn.Module):
    """
    Mask estimator operating on STFT magnitudes.

    Takes |X| of an STFT with shape (B, C, F, T_spec) and predicts complex
    masks in folded form with shape (B, 2 * (sources * channels), F, T_spec),
    where the factor 2 stands for the [real, imag] parts.

    Every STFT frame (all channels and frequency bins flattened together) is
    linearly projected to ``embed_dim``, processed by a Conformer along the
    time axis, and projected back to ``freq_bins * sources * in_channels * 2``
    values per frame.
    """

    def __init__(
        self,
        in_channels: int = 2,
        sources: int = 2,
        freq_bins: int = 2049,
        embed_dim: int = 512,
        depth: int = 8,
        dim_head: int = 64,
        heads: int = 8,
        ff_mult: int = 4,
        conv_expansion_factor: int = 2,
        conv_kernel_size: int = 31,
        attn_dropout: float = 0.1,
        ff_dropout: float = 0.1,
        conv_dropout: float = 0.1,
    ):
        """
        Args:
            in_channels: number of audio channels C expected in the input.
            sources: number of sources S to predict masks for.
            freq_bins: number of STFT frequency bins F expected in the input.
            embed_dim: model width used between the two projections.
            depth, dim_head, heads, ff_mult, conv_expansion_factor,
            conv_kernel_size, attn_dropout, ff_dropout, conv_dropout:
                forwarded unchanged to ``Conformer``.
        """
        super().__init__()
        self.freq_bins = freq_bins
        self.in_channels = in_channels
        self.sources = sources
        # One complex mask per (source, channel) pair.
        self.out_masks = sources * in_channels
        self.embed_dim = embed_dim

        # Flattened (C * F) frame -> embedding.
        self.input_proj_stft = nn.Linear(freq_bins * in_channels, embed_dim)
        self.model = Conformer(
            dim=embed_dim,
            depth=depth,
            dim_head=dim_head,
            heads=heads,
            ff_mult=ff_mult,
            conv_expansion_factor=conv_expansion_factor,
            conv_kernel_size=conv_kernel_size,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            conv_dropout=conv_dropout,
        )

        # Embedding -> F * out_masks * [real, imag] values per frame.
        self.output_proj = nn.Linear(embed_dim, freq_bins * self.out_masks * 2)

    def forward(self, x_stft_mag: torch.Tensor) -> torch.Tensor:
        """
        Predict folded real/imag masks from an STFT magnitude.

        Args:
            x_stft_mag: tensor of shape (B, C, F, T_spec) with
                C == in_channels and F == freq_bins.

        Returns:
            Tensor of shape (B, 2 * (sources * channels), F, T_spec) holding
            the real/imag parts of the masks.

        Raises:
            AssertionError: if the input is not 4-D.
            ValueError: if C or F do not match ``in_channels`` / ``freq_bins``.
        """
        assert x_stft_mag.dim() == 4, f"Expected (B,C,F,T), got {tuple(x_stft_mag.shape)}"
        # Local names avoid shadowing the module-level ``F`` (torch.nn.functional).
        batch, n_chan, n_freq, n_frames = x_stft_mag.shape

        # Fail fast with a readable message instead of an opaque matmul shape
        # error inside input_proj_stft.
        if n_chan != self.in_channels or n_freq != self.freq_bins:
            raise ValueError(
                f"Expected input of shape (B, {self.in_channels}, {self.freq_bins}, T), "
                f"got {tuple(x_stft_mag.shape)}"
            )

        # (B, C, F, T) -> (B, T, C*F): one feature vector per STFT frame.
        frames = x_stft_mag.permute(0, 3, 1, 2).reshape(batch, n_frames, n_chan * n_freq)

        x = self.input_proj_stft(frames)
        # NOTE(review): assumes Conformer maps (B, T, embed_dim) -> (B, T, embed_dim);
        # output_proj and the final reshape rely on that.
        x = self.model(x)
        # NOTE(review): tanh is applied to the Conformer features, not to the
        # projected masks, so the mask values themselves are unbounded —
        # confirm this is intended.
        x = torch.tanh(x)
        x = self.output_proj(x)

        # (B, T, out_masks*2*F) -> (B, out_masks*2, F, T)
        x = x.reshape(batch, n_frames, self.out_masks * 2, n_freq)
        return x.permute(0, 2, 3, 1).contiguous()
|
|
|
|
class ConformerMSS(nn.Module):
    """
    Waveform-in / waveform-out source separation wrapper.

    Compatible with the training loop:
        forward(x: (B, C, T)) -> y_hat: (B, S, C, T)
    where S is the number of sources (``core.sources``).
    Internally: STFT -> core (mask estimator) -> complex masks -> iSTFT.
    """

    def __init__(
        self,
        core: "NeuralModel",
        n_fft: int = 4096,
        hop_length: int = 1024,
        win_length: Optional[int] = None,
        center: bool = True,
    ):
        """
        Args:
            core: mask estimator. Must expose ``freq_bins`` (== n_fft // 2 + 1)
                and ``sources``, and map a (B, C, F, TT) magnitude to a
                (B, 2 * S * C, F, TT) tensor of real/imag masks.
            n_fft: FFT size.
            hop_length: hop between STFT frames, in samples.
            win_length: Hann window length; defaults to ``n_fft``.
            center: forwarded to ``torch.stft`` / ``torch.istft``.
        """
        super().__init__()
        self.core = core
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length if win_length is not None else n_fft
        self.center = center

        # Non-persistent buffer: follows the module across devices but is not
        # stored in state_dict (it is rebuilt from win_length here).
        window = torch.hann_window(self.win_length)
        self.register_buffer("window", window, persistent=False)

        expected_bins = n_fft // 2 + 1
        assert core.freq_bins == expected_bins, (
            f"NeuralModel.freq_bins={core.freq_bins} != n_fft//2+1={expected_bins}. "
            f"Поставь freq_bins={expected_bins} при создании core."
        )

    def _stft(self, x: torch.Tensor) -> torch.Tensor:
        """
        Per-channel STFT.

        x: (B, C, T) -> spec: complex (B, C, F, TT)
        """
        assert x.dim() == 3, f"Expected (B,C,T), got {tuple(x.shape)}"
        batch, n_chan, n_samples = x.shape
        spec = torch.stft(
            x.reshape(batch * n_chan, n_samples),
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window.to(x.device),
            center=self.center,
            return_complex=True,
        )
        n_freq, n_frames = spec.shape[-2], spec.shape[-1]
        return spec.reshape(batch, n_chan, n_freq, n_frames)

    def _istft(self, spec: torch.Tensor, length: int) -> torch.Tensor:
        """
        Per-channel inverse STFT.

        spec: complex (B, C, F, TT) -> audio: (B, C, length)
        """
        batch, n_chan, n_freq, n_frames = spec.shape
        audio = torch.istft(
            spec.reshape(batch * n_chan, n_freq, n_frames),
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window.to(spec.device),
            center=self.center,
            length=length,
        )
        return audio.reshape(batch, n_chan, -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Separate a mixture waveform into sources.

        Args:
            x: mixture waveform of shape (B, C, T).

        Returns:
            y_hat: estimated source waveforms of shape (B, S, C, T).
        """
        # _stft runs first so its rank assertion reports malformed input.
        mix_spec = self._stft(x)  # complex (B, C, F, TT)
        batch, n_chan, n_samples = x.shape
        n_src = self.core.sources

        mask_ri = self.core(mix_spec.abs())
        _, n_mask_ch, n_freq, mask_frames = mask_ri.shape
        assert n_mask_ch == 2 * (n_src * n_chan), (
            f"core вернул {n_mask_ch} каналов масок, ожидалось {2*(n_src*n_chan)} "
            f"(2*[real/imag]*[sources*channels]). Проверь in_channels/sources."
        )

        # Defensive: align frame counts in case core changed the time axis.
        n_frames = min(mix_spec.shape[-1], mask_frames)
        mix_spec = mix_spec[..., :n_frames]
        mask_ri = mask_ri[..., :n_frames]

        # Unfold (2*S*C) -> (2, S, C). reshape (not view) so a non-contiguous
        # core output or the slice above never breaks it.
        mask_ri = mask_ri.reshape(batch, 2, n_src, n_chan, n_freq, n_frames)
        masks_c = torch.complex(mask_ri[:, 0], mask_ri[:, 1])  # (B, S, C, F, TT)

        # Apply each source's mask to the mixture (broadcast over the S axis).
        est_specs = masks_c * mix_spec.unsqueeze(1)  # (B, S, C, F, TT)

        # A single batched iSTFT over all (source, channel) pairs instead of a
        # Python loop over sources.
        est_specs = est_specs.reshape(batch, n_src * n_chan, n_freq, n_frames)
        y_hat = self._istft(est_specs, length=n_samples)  # (B, S*C, T)
        return y_hat.reshape(batch, n_src, n_chan, n_samples)