xjsc0's picture
1
64ec292
from typing import Optional
import torch
import torch.nn as nn
from conformer import Conformer
class NeuralModel(nn.Module):
"""
Принимает |X| STFT: (B, C, F, T_spec) и предсказывает комплексные маски
в свернутом виде: (B, 2 * (sources*channels), F, T_spec)
где 2 — это [real, imag].
"""
def __init__(
self,
in_channels: int = 2,
sources: int = 2,
freq_bins: int = 2049,
embed_dim: int = 512,
depth: int = 8,
dim_head: int = 64,
heads: int = 8,
ff_mult: int = 4,
conv_expansion_factor: int = 2,
conv_kernel_size: int = 31,
attn_dropout: float = 0.1,
ff_dropout: float = 0.1,
conv_dropout: float = 0.1,
):
super().__init__()
self.freq_bins = freq_bins
self.in_channels = in_channels
self.sources = sources
self.out_masks = sources * in_channels
self.embed_dim = embed_dim
self.input_proj_stft = nn.Linear(freq_bins * in_channels, embed_dim)
self.model = Conformer(
dim=embed_dim,
depth=depth,
dim_head=dim_head,
heads=heads,
ff_mult=ff_mult,
conv_expansion_factor=conv_expansion_factor,
conv_kernel_size=conv_kernel_size,
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
conv_dropout=conv_dropout,
)
# 2 = [real, imag]
self.output_proj = nn.Linear(embed_dim, freq_bins * self.out_masks * 2)
def forward(self, x_stft_mag: torch.Tensor) -> torch.Tensor:
"""
x_stft_mag: (B, C, F, T_spec)
returns: (B, 2 * (sources*channels), F, T_spec) — real/imag масок
"""
assert x_stft_mag.dim() == 4, (
f"Expected (B,C,F,T), got {tuple(x_stft_mag.shape)}"
)
B, C, F, T_spec = x_stft_mag.shape
# (B, T_spec, C*F)
x_stft_mag = x_stft_mag.permute(0, 3, 1, 2).contiguous().view(B, T_spec, C * F)
x = self.input_proj_stft(x_stft_mag) # (B, T_spec, E)
x = self.model(x) # (B, T_spec, E)
x = torch.tanh(x) # стабилизируем
x = self.output_proj(x) # (B, T_spec, F * out_masks * 2)
# back to (B, 2*out_masks, F, T_spec)
x = x.reshape(B, T_spec, self.out_masks * 2, F).permute(0, 2, 3, 1).contiguous()
return x
class ConformerMSS(nn.Module):
"""
Совместимо с твоим train:
forward(x: (B, C, T)) -> y_hat: (B, S, C, T)
где S = число источников (sources).
Внутри: STFT -> NeuralModel -> комплексные маски -> iSTFT.
"""
def __init__(
self,
core: NeuralModel,
n_fft: int = 4096,
hop_length: int = 1024,
win_length: Optional[int] = None,
center: bool = True,
):
super().__init__()
self.core = core
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length if win_length is not None else n_fft
self.center = center
window = torch.hann_window(self.win_length)
# окно — буфер, чтобы таскалось на .to(device)
self.register_buffer("window", window, persistent=False)
# sanity-check: freq_bins у core должен совпадать с n_fft//2 + 1
expected_bins = n_fft // 2 + 1
assert core.freq_bins == expected_bins, (
f"NeuralModel.freq_bins={core.freq_bins} != n_fft//2+1={expected_bins}. "
f"Поставь freq_bins={expected_bins} при создании core."
)
def _stft(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (B, C, T) -> spec: complex (B, C, F, TT)
"""
assert x.dim() == 3, f"Expected (B,C,T), got {tuple(x.shape)}"
B, C, T = x.shape
x_bc_t = x.reshape(B * C, T)
spec = torch.stft(
x_bc_t,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window.to(x.device),
center=self.center,
return_complex=True,
) # (B*C, F, TT)
F, TT = spec.shape[-2], spec.shape[-1]
spec = spec.reshape(B, C, F, TT)
return spec
def _istft(self, spec: torch.Tensor, length: int) -> torch.Tensor:
"""
spec: complex (B, C, F, TT) -> audio: (B, C, T)
"""
B, C, F, TT = spec.shape
spec_bc = spec.reshape(B * C, F, TT)
y_bc_t = torch.istft(
spec_bc,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window.to(spec.device),
center=self.center,
length=length,
)
return y_bc_t.reshape(B, C, -1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (B, C, T) (микс в волне)
returns y_hat: (B, S, C, T) — предсказанные источники в волне
"""
B, C, T = x.shape
# 1) STFT
mix_spec = self._stft(x) # (B, C, F, TT)
mix_mag = mix_spec.abs() # (B, C, F, TT)
# 2) Прогон через core -> real/imag масок
mask_ri = self.core(mix_mag) # (B, 2*(S*C), F, TT2)
_, two_sc, F, TT2 = mask_ri.shape
S = self.core.sources
assert two_sc == 2 * (S * C), (
f"core вернул {two_sc} каналов масок, ожидалось {2 * (S * C)} "
f"(2*[real/imag]*[sources*channels]). Проверь in_channels/sources."
)
# 3) Синхронизация по времени (если вдруг TT != TT2)
TT = mix_spec.shape[-1]
TT_min = min(TT, TT2)
if TT != TT_min:
mix_spec = mix_spec[..., :TT_min]
if TT2 != TT_min:
mask_ri = mask_ri[..., :TT_min]
TT = TT_min
# теперь у обоих время = TT
# 4) Преобразуем к (B, 2, S, C, F, TT)
mask_ri = mask_ri.view(B, 2, S, C, F, TT).contiguous()
mask_real = mask_ri[:, 0] # (B, S, C, F, TT)
mask_imag = mask_ri[:, 1] # (B, S, C, F, TT)
masks_c = torch.complex(mask_real, mask_imag)
# 5) Применяем маски к комплексному спектру микса
mix_spec_bc = mix_spec.unsqueeze(1) # (B, 1, C, F, TT)
est_specs = masks_c * mix_spec_bc # (B, S, C, F, TT)
# 6) iSTFT по каждому источнику
outs = []
for s in range(S):
y_s = self._istft(est_specs[:, s], length=T) # (B, C, T)
outs.append(y_s)
y_hat = torch.stack(outs, dim=1) # (B, S, C, T)
return y_hat