Spaces:
Running on Zero
Running on Zero
File size: 6,924 Bytes
64ec292 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 | from typing import Optional
import torch
import torch.nn as nn
from conformer import Conformer
class NeuralModel(nn.Module):
    """Predict folded complex separation masks from an STFT magnitude.

    Input:  |X| with shape (B, C, F, T_spec).
    Output: a tensor of shape (B, 2 * (sources * channels), F, T_spec); the
            factor 2 holds the [real, imag] parts of the masks.

    Every STFT frame (all channels x all frequency bins) is flattened,
    linearly projected to ``embed_dim``, processed by a Conformer over the
    frame (time) axis, squashed with tanh, and linearly projected back to
    ``freq_bins * out_masks * 2`` values per frame.
    """

    def __init__(
        self,
        in_channels: int = 2,
        sources: int = 2,
        freq_bins: int = 2049,
        embed_dim: int = 512,
        depth: int = 8,
        dim_head: int = 64,
        heads: int = 8,
        ff_mult: int = 4,
        conv_expansion_factor: int = 2,
        conv_kernel_size: int = 31,
        attn_dropout: float = 0.1,
        ff_dropout: float = 0.1,
        conv_dropout: float = 0.1,
    ):
        super().__init__()
        # Plain configuration attributes (read by the surrounding wrapper too).
        self.in_channels = in_channels
        self.sources = sources
        self.freq_bins = freq_bins
        self.embed_dim = embed_dim
        # One mask per (source, input channel) pair.
        self.out_masks = sources * in_channels

        per_frame_in = in_channels * freq_bins
        # 2 == [real, imag] for every mask value.
        per_frame_out = 2 * self.out_masks * freq_bins

        # Submodules are registered in the same order and under the same names
        # as before so that state_dict keys stay unchanged.
        self.input_proj_stft = nn.Linear(per_frame_in, embed_dim)
        self.model = Conformer(
            dim=embed_dim,
            depth=depth,
            dim_head=dim_head,
            heads=heads,
            ff_mult=ff_mult,
            conv_expansion_factor=conv_expansion_factor,
            conv_kernel_size=conv_kernel_size,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            conv_dropout=conv_dropout,
        )
        self.output_proj = nn.Linear(embed_dim, per_frame_out)

    def forward(self, x_stft_mag: torch.Tensor) -> torch.Tensor:
        """Map a magnitude spectrogram to folded real/imag masks.

        Args:
            x_stft_mag: magnitude tensor of shape (B, C, F, T_spec).

        Returns:
            Tensor of shape (B, 2 * (sources * channels), F, T_spec).
        """
        assert x_stft_mag.ndim == 4, (
            f"Expected (B,C,F,T), got {tuple(x_stft_mag.shape)}"
        )
        batch, n_chan, n_freq, n_frames = x_stft_mag.shape

        # (B, C, F, T) -> (B, T, C, F) -> (B, T, C*F): one feature vector per frame.
        frames = x_stft_mag.movedim(-1, 1).flatten(2)

        hidden = self.input_proj_stft(frames)  # (B, T, E)
        hidden = self.model(hidden)  # (B, T, E)
        # tanh bounds the embedding (not the masks; output_proj is unbounded).
        hidden = torch.tanh(hidden)
        flat_masks = self.output_proj(hidden)  # (B, T, F * out_masks * 2)

        # (B, T, 2*out_masks, F) -> (B, 2*out_masks, F, T)
        grid = flat_masks.reshape(batch, n_frames, self.out_masks * 2, n_freq)
        return grid.movedim(1, -1).contiguous()
class ConformerMSS(nn.Module):
    """Waveform-in / waveform-out source separation around a mask network.

    Interface (compatible with the training loop this model is used with):
        forward(x: (B, C, T)) -> y_hat: (B, S, C, T)
    where S is the number of sources (``core.sources``).

    Pipeline: STFT -> core(|STFT|) -> complex masks -> mask * STFT -> iSTFT.

    Args:
        core: mask network. It must expose ``freq_bins`` and ``sources`` and
            map a magnitude tensor (B, C, F, TT) to (B, 2 * S * C, F, TT); the
            channel axis is read here as (2 = [real, imag], S, C).
        n_fft: FFT size; ``core.freq_bins`` must equal ``n_fft // 2 + 1``.
        hop_length: STFT hop in samples.
        win_length: Hann window length; defaults to ``n_fft``.
        center: forwarded to ``torch.stft`` / ``torch.istft``.
    """

    def __init__(
        self,
        core: "NeuralModel",
        n_fft: int = 4096,
        hop_length: int = 1024,
        win_length: Optional[int] = None,
        center: bool = True,
    ):
        super().__init__()
        self.core = core
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length if win_length is not None else n_fft
        self.center = center
        window = torch.hann_window(self.win_length)
        # Buffer so the window follows .to(device); non-persistent, so it is
        # not stored in the state_dict.
        self.register_buffer("window", window, persistent=False)
        # Sanity check: the core's frequency axis must match the onesided STFT.
        expected_bins = n_fft // 2 + 1
        assert core.freq_bins == expected_bins, (
            f"NeuralModel.freq_bins={core.freq_bins} != n_fft//2+1={expected_bins}. "
            f"Поставь freq_bins={expected_bins} при создании core."
        )

    def _stft(self, x: torch.Tensor) -> torch.Tensor:
        """Per-channel STFT: real (B, C, T) -> complex (B, C, F, TT)."""
        assert x.dim() == 3, f"Expected (B,C,T), got {tuple(x.shape)}"
        B, C, T = x.shape
        spec = torch.stft(
            x.reshape(B * C, T),
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window.to(x.device),
            center=self.center,
            return_complex=True,
        )  # (B*C, F, TT)
        return spec.reshape(B, C, spec.shape[-2], spec.shape[-1])

    def _istft(self, spec: torch.Tensor, length: int) -> torch.Tensor:
        """Per-channel iSTFT: complex (B, C, F, TT) -> real (B, C, length)."""
        B, C, F, TT = spec.shape
        y_bc_t = torch.istft(
            spec.reshape(B * C, F, TT),
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window.to(spec.device),
            center=self.center,
            length=length,
        )
        return y_bc_t.reshape(B, C, -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Separate a waveform mixture into sources.

        Args:
            x: mixture waveform, shape (B, C, T).

        Returns:
            Estimated source waveforms, shape (B, S, C, T).
        """
        B, C, T = x.shape

        # 1) STFT of the mixture and its magnitude (the core's input).
        mix_spec = self._stft(x)  # complex (B, C, F, TT)
        mix_mag = mix_spec.abs()  # real (B, C, F, TT)

        # 2) Core -> folded real/imag masks.
        mask_ri = self.core(mix_mag)  # (B, 2*(S*C), F, TT2)
        _, two_sc, F, TT2 = mask_ri.shape
        S = self.core.sources
        assert two_sc == 2 * (S * C), (
            f"core вернул {two_sc} каналов масок, ожидалось {2 * (S * C)} "
            f"(2*[real/imag]*[sources*channels]). Проверь in_channels/sources."
        )

        # 3) Align frame counts defensively (trim both to the shorter one).
        TT = min(mix_spec.shape[-1], TT2)
        mix_spec = mix_spec[..., :TT]
        mask_ri = mask_ri[..., :TT]

        # 4) Build complex masks (B, S, C, F, TT). torch.complex only accepts
        # half/float/double with matching dtypes; under autocast the core may
        # return bfloat16/float16, so cast to the spectrum's real dtype first.
        mask_ri = mask_ri.to(mix_mag.dtype).reshape(B, 2, S, C, F, TT)
        masks_c = torch.complex(mask_ri[:, 0], mask_ri[:, 1])

        # 5) Apply the masks to the mixture spectrum (broadcast over sources).
        est_specs = masks_c * mix_spec.unsqueeze(1)  # (B, S, C, F, TT)

        # 6) One batched iSTFT over every (source, channel) signal.
        y = self._istft(est_specs.reshape(B, S * C, F, TT), length=T)  # (B, S*C, T)
        return y.reshape(B, S, C, T)
|