from typing import Optional import torch import torch.nn as nn from conformer import Conformer class NeuralModel(nn.Module): """ Принимает |X| STFT: (B, C, F, T_spec) и предсказывает комплексные маски в свернутом виде: (B, 2 * (sources*channels), F, T_spec) где 2 — это [real, imag]. """ def __init__( self, in_channels: int = 2, sources: int = 2, freq_bins: int = 2049, embed_dim: int = 512, depth: int = 8, dim_head: int = 64, heads: int = 8, ff_mult: int = 4, conv_expansion_factor: int = 2, conv_kernel_size: int = 31, attn_dropout: float = 0.1, ff_dropout: float = 0.1, conv_dropout: float = 0.1, ): super().__init__() self.freq_bins = freq_bins self.in_channels = in_channels self.sources = sources self.out_masks = sources * in_channels self.embed_dim = embed_dim self.input_proj_stft = nn.Linear(freq_bins * in_channels, embed_dim) self.model = Conformer( dim=embed_dim, depth=depth, dim_head=dim_head, heads=heads, ff_mult=ff_mult, conv_expansion_factor=conv_expansion_factor, conv_kernel_size=conv_kernel_size, attn_dropout=attn_dropout, ff_dropout=ff_dropout, conv_dropout=conv_dropout, ) # 2 = [real, imag] self.output_proj = nn.Linear(embed_dim, freq_bins * self.out_masks * 2) def forward(self, x_stft_mag: torch.Tensor) -> torch.Tensor: """ x_stft_mag: (B, C, F, T_spec) returns: (B, 2 * (sources*channels), F, T_spec) — real/imag масок """ assert x_stft_mag.dim() == 4, ( f"Expected (B,C,F,T), got {tuple(x_stft_mag.shape)}" ) B, C, F, T_spec = x_stft_mag.shape # (B, T_spec, C*F) x_stft_mag = x_stft_mag.permute(0, 3, 1, 2).contiguous().view(B, T_spec, C * F) x = self.input_proj_stft(x_stft_mag) # (B, T_spec, E) x = self.model(x) # (B, T_spec, E) x = torch.tanh(x) # стабилизируем x = self.output_proj(x) # (B, T_spec, F * out_masks * 2) # back to (B, 2*out_masks, F, T_spec) x = x.reshape(B, T_spec, self.out_masks * 2, F).permute(0, 2, 3, 1).contiguous() return x class ConformerMSS(nn.Module): """ Совместимо с твоим train: forward(x: (B, C, T)) -> y_hat: (B, S, C, T) где S = число источников (sources). Внутри: STFT -> NeuralModel -> комплексные маски -> iSTFT. """ def __init__( self, core: NeuralModel, n_fft: int = 4096, hop_length: int = 1024, win_length: Optional[int] = None, center: bool = True, ): super().__init__() self.core = core self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length if win_length is not None else n_fft self.center = center window = torch.hann_window(self.win_length) # окно — буфер, чтобы таскалось на .to(device) self.register_buffer("window", window, persistent=False) # sanity-check: freq_bins у core должен совпадать с n_fft//2 + 1 expected_bins = n_fft // 2 + 1 assert core.freq_bins == expected_bins, ( f"NeuralModel.freq_bins={core.freq_bins} != n_fft//2+1={expected_bins}. " f"Поставь freq_bins={expected_bins} при создании core." ) def _stft(self, x: torch.Tensor) -> torch.Tensor: """ x: (B, C, T) -> spec: complex (B, C, F, TT) """ assert x.dim() == 3, f"Expected (B,C,T), got {tuple(x.shape)}" B, C, T = x.shape x_bc_t = x.reshape(B * C, T) spec = torch.stft( x_bc_t, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window.to(x.device), center=self.center, return_complex=True, ) # (B*C, F, TT) F, TT = spec.shape[-2], spec.shape[-1] spec = spec.reshape(B, C, F, TT) return spec def _istft(self, spec: torch.Tensor, length: int) -> torch.Tensor: """ spec: complex (B, C, F, TT) -> audio: (B, C, T) """ B, C, F, TT = spec.shape spec_bc = spec.reshape(B * C, F, TT) y_bc_t = torch.istft( spec_bc, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window.to(spec.device), center=self.center, length=length, ) return y_bc_t.reshape(B, C, -1) def forward(self, x: torch.Tensor) -> torch.Tensor: """ x: (B, C, T) (микс в волне) returns y_hat: (B, S, C, T) — предсказанные источники в волне """ B, C, T = x.shape # 1) STFT mix_spec = self._stft(x) # (B, C, F, TT) mix_mag = mix_spec.abs() # (B, C, F, TT) # 2) Прогон через core -> real/imag масок mask_ri = self.core(mix_mag) # (B, 2*(S*C), F, TT2) _, two_sc, F, TT2 = mask_ri.shape S = self.core.sources assert two_sc == 2 * (S * C), ( f"core вернул {two_sc} каналов масок, ожидалось {2 * (S * C)} " f"(2*[real/imag]*[sources*channels]). Проверь in_channels/sources." ) # 3) Синхронизация по времени (если вдруг TT != TT2) TT = mix_spec.shape[-1] TT_min = min(TT, TT2) if TT != TT_min: mix_spec = mix_spec[..., :TT_min] if TT2 != TT_min: mask_ri = mask_ri[..., :TT_min] TT = TT_min # теперь у обоих время = TT # 4) Преобразуем к (B, 2, S, C, F, TT) mask_ri = mask_ri.view(B, 2, S, C, F, TT).contiguous() mask_real = mask_ri[:, 0] # (B, S, C, F, TT) mask_imag = mask_ri[:, 1] # (B, S, C, F, TT) masks_c = torch.complex(mask_real, mask_imag) # 5) Применяем маски к комплексному спектру микса mix_spec_bc = mix_spec.unsqueeze(1) # (B, 1, C, F, TT) est_specs = masks_c * mix_spec_bc # (B, S, C, F, TT) # 6) iSTFT по каждому источнику outs = [] for s in range(S): y_s = self._istft(est_specs[:, s], length=T) # (B, C, T) outs.append(y_s) y_hat = torch.stack(outs, dim=1) # (B, S, C, T) return y_hat