| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
def rfft_mag(x, n_fft: int):
    """Return the one-sided FFT of ``x`` together with its magnitude.

    The transform runs along the last dimension. ``torch.fft.rfft`` zero-pads
    or trims the input to ``n_fft`` samples before transforming.

    Args:
        x: Real-valued tensor.
        n_fft: FFT length passed to ``torch.fft.rfft`` as ``n``.

    Returns:
        Tuple ``(spectrum, magnitude)``: the complex spectrum and its
        element-wise absolute value.
    """
    spectrum = torch.fft.rfft(x, n=n_fft)
    magnitude = torch.abs(spectrum)
    return spectrum, magnitude
|
|
def omega_grid_like(spec):
    """Build an evenly spaced frequency grid on ``[0, pi]`` matching ``spec``.

    Args:
        spec: Tensor whose last dimension is the number of frequency bins.

    Returns:
        1-D tensor with ``spec.shape[-1]`` points from 0 to pi (inclusive),
        created on ``spec``'s device with the default floating dtype.
    """
    num_bins = spec.size(-1)
    return torch.linspace(0.0, torch.pi, steps=num_bins, device=spec.device)
|
|
class NeuralWavelet1D(nn.Module):
    """Learnable low-pass / high-pass FIR pair applied per channel, no downsampling.

    Two kernels are learned: ``g0`` (intended low-pass) and ``g1`` (intended
    high-pass). Each is rescaled to unit L2 norm per channel before use, and
    ``spectral_losses`` supplies regularisers that push the kernels toward
    the intended band shapes.

    Args:
        channels: Number of input channels ``C`` (also the conv group count).
        kernel_size: FIR length ``L``. With an even value the padding
            ``L // 2`` makes ``forward`` return ``T + 1`` samples, not ``T``.
        share_across_channels: If true, one kernel pair is learned and
            broadcast to all channels; otherwise each channel owns a pair.
        n_fft: FFT size used by the spectral regularisers.
        band_power_p: Exponent of the frequency weighting in the band priors.
    """

    def __init__(self, channels: int, kernel_size: int = 81, share_across_channels: bool = True,
                 n_fft: int = 256, band_power_p: float = 3.0):
        super().__init__()
        self.C = channels
        self.L = kernel_size
        self.share = share_across_channels
        self.n_fft = n_fft
        self.p = band_power_p

        shape = (1, 1, self.L) if share_across_channels else (self.C, 1, self.L)

        self.g0 = nn.Parameter(torch.randn(*shape) * 0.05)
        self.g1 = nn.Parameter(torch.randn(*shape) * 0.05)

        # Start the high-pass kernel with zero DC gain.
        with torch.no_grad():
            self.g1.sub_(self.g1.mean(dim=-1, keepdim=True))

    def _filters(self):
        """Return ``(g0, g1)`` as ``[C, 1, L]`` kernels with unit L2 norm per channel."""
        if self.share:
            g0 = self.g0.expand(self.C, 1, self.L)
            g1 = self.g1.expand(self.C, 1, self.L)
        else:
            g0, g1 = self.g0, self.g1
        g0 = g0 / (g0.norm(dim=(-2, -1), keepdim=True) + 1e-18)
        g1 = g1 / (g1.norm(dim=(-2, -1), keepdim=True) + 1e-18)
        return g0, g1

    def forward(self, x):
        """Filter ``x`` into low and high sub-bands without downsampling.

        Args:
            x: Input of shape ``[B, C, T]``.

        Returns:
            Tuple ``(y_low, y_high)``, each the depthwise convolution of ``x``
            with the corresponding normalised kernel.
        """
        g0, g1 = self._filters()
        y_low = F.conv1d(x, g0, padding=self.L // 2, groups=self.C)
        y_high = F.conv1d(x, g1, padding=self.L // 2, groups=self.C)
        return y_low, y_high

    @staticmethod
    def _shape_loss(g0, g1):
        """Return the zero-mean (g1) plus symmetry (g0) penalty for ``[C, 1, L]`` kernels."""
        # Square each channel's DC gain *before* averaging: averaging first
        # lets opposite-signed offsets of different channels cancel to zero.
        zero_mean_high = g1.sum(dim=-1).pow(2).mean()
        sym_low = (g0 - torch.flip(g0, dims=[-1])).pow(2).mean()
        return zero_mean_high + sym_low

    def spectral_losses(self):
        """Compute the regularisation terms for the filter pair.

        Returns:
            Tuple ``(L_low, L_high, L_overlap, L_parseval, L_shape)``:

            * ``L_low``: low-pass prior on ``g0`` (penalises high-frequency energy).
            * ``L_high``: high-pass prior on ``g1`` (penalises low-frequency energy).
            * ``L_overlap``: spectral overlap, mean of ``|G0|^2 * |G1|^2``.
            * ``L_parseval``: squared deviation of ``|G0|^2 + |G1|^2`` from 1.
            * ``L_shape``: ``g1`` zero-mean penalty plus ``g0`` symmetry penalty.

        The spectral terms are evaluated on the channel-averaged kernels.

        NOTE(review): both kernels are scaled to unit L2 norm, so by Parseval
        the average of ``|G0|^2 + |G1|^2`` over frequency is about 2; the
        target of 1.0 in ``L_parseval`` therefore looks unreachable. Left
        unchanged here -- confirm the intended normalisation.
        """
        g0, g1 = self._filters()
        g0m = g0.mean(0).squeeze(0)
        g1m = g1.mean(0).squeeze(0)

        # rfft trims inputs longer than n; never let it cut the kernel short.
        n_fft = max(self.n_fft, self.L)
        _, A0 = rfft_mag(g0m, n_fft)
        _, A1 = rfft_mag(g1m, n_fft)

        omega = omega_grid_like(A0)
        # Weights grow toward the band each filter is supposed to reject.
        w_low = (omega / torch.pi) ** self.p
        w_high = ((torch.pi - omega) / torch.pi) ** self.p

        L_low = (w_low * (A0 ** 2)).mean()
        L_high = (w_high * (A1 ** 2)).mean()

        L_overlap = (A0 ** 2 * A1 ** 2).mean()
        L_parseval = ((A0 ** 2 + A1 ** 2 - 1.0) ** 2).mean()

        L_shape = self._shape_loss(g0, g1)

        return L_low, L_high, L_overlap, L_parseval, L_shape
|
|
class LowGuidedHighCrossAttn(nn.Module):
    """Multi-head attention where the low band sets the pattern and the high band supplies values.

    Queries and keys are both projected from the low-frequency sequence and
    values from the high-frequency sequence. The attended result is added
    back onto the low-frequency input. The output projection is
    zero-initialised, so a freshly built module returns ``L`` unchanged.

    Args:
        d_model: Feature width; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        use_gamma: If true, the attended branch is scaled by
            ``0.01 * sigmoid(gamma)`` with a learnable scalar ``gamma``.
    """

    def __init__(self, d_model, nhead=4, dropout=0.0, use_gamma=True):
        super().__init__()
        self.nhead = nhead
        self.dk = d_model // nhead
        assert d_model % nhead == 0

        # Creation order is kept fixed so seeded initialisation is reproducible.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        for param in (self.out_proj.weight, self.out_proj.bias):
            nn.init.zeros_(param)

        self.dropout = nn.Dropout(dropout)
        self.use_gamma = use_gamma
        if use_gamma:
            self.gamma = nn.Parameter(torch.zeros(1))

    def _split_heads(self, x):
        """Reshape ``[B, T, D]`` into per-head layout ``[B, h, T, dk]``."""
        batch, steps, _ = x.shape
        per_head = x.view(batch, steps, self.nhead, self.dk)
        return per_head.transpose(1, 2)

    def _merge_heads(self, x):
        """Inverse of ``_split_heads``: ``[B, h, T, dk]`` back to ``[B, T, h * dk]``."""
        batch, heads, steps, width = x.shape
        return x.transpose(1, 2).contiguous().view(batch, steps, heads * width)

    def forward(self, L, H, attn_mask=None, attn_bias=None):
        """Attend over high-band values using low-band queries and keys.

        Args:
            L: Low-frequency backbone, ``[B, T, d]``.
            H: High-frequency detail, ``[B, T, d]``.
            attn_mask: Optional tensor added to the attention logits (additive
                mask; it is not interpreted as a boolean mask).
            attn_bias: Optional tensor added to the attention logits.

        Returns:
            Tuple ``(out, attn)``: ``out`` is ``[B, T, d]`` (``L`` plus the
            attended branch) and ``attn`` is ``[B, h, T, T]``, the attention
            weights after dropout.
        """
        queries = self._split_heads(self.q_proj(L))
        keys = self._split_heads(self.k_proj(L))
        values = self._split_heads(self.v_proj(H))

        logits = (queries @ keys.transpose(-2, -1)) / (self.dk ** 0.5)
        # Bias first, then mask -- both are purely additive.
        for extra in (attn_bias, attn_mask):
            if extra is not None:
                logits = logits + extra

        attn = self.dropout(torch.softmax(logits, dim=-1))

        branch = self.out_proj(self._merge_heads(attn @ values))
        if self.use_gamma:
            branch = 0.01 * torch.sigmoid(self.gamma) * branch

        return L + branch, attn
|
|
class WaveletFront(nn.Module):
    """Wavelet front-end: split into two bands, embed each, fuse with cross-attention.

    Args:
        in_channels: Number of input channels ``C``.
        d_model: Embedding width of the output sequence.
        kernel_size: FIR length passed to ``NeuralWavelet1D``.
        n_fft: FFT size passed to ``NeuralWavelet1D`` for its regularisers.
        nhead: Number of heads in the cross-attention (``d_model`` must be
            divisible by it). Defaults to the previously hard-coded 4.
        dropout: Attention dropout probability. Defaults to the previously
            hard-coded 0.1.
    """

    def __init__(self, in_channels=1, d_model=256, kernel_size=31, n_fft=256,
                 nhead=4, dropout=0.1):
        super().__init__()
        # One kernel pair shared by all channels, band-prior exponent 3.0.
        self.wave = NeuralWavelet1D(in_channels, kernel_size, True, n_fft, band_power_p=3.0)

        # 1x1 convolutions lift each band from C channels to d_model features.
        self.proj_low = nn.Conv1d(in_channels, d_model, kernel_size=1)
        self.proj_high = nn.Conv1d(in_channels, d_model, kernel_size=1)
        self.cross = LowGuidedHighCrossAttn(d_model, nhead=nhead, dropout=dropout, use_gamma=True)

    def forward(self, x, attn_mask=None, attn_bias=None):
        """Run the front-end on a batch of signals.

        Args:
            x: Input of shape ``[B, C, T]``.
            attn_mask: Optional additive mask forwarded to the cross-attention.
            attn_bias: Optional additive bias forwarded to the cross-attention.

        Returns:
            Tuple ``(y, reg, attn)``:

            * ``y``: ``[B, T, d_model]`` sequence, ready for positional /
              temporal encodings and a Transformer.
            * ``reg``: dict of the regularisation losses from the wavelet.
            * ``attn``: cross-attention weights (for inspection or regularisation).
        """
        y_low, y_high = self.wave(x)

        # Parameter-free layer norm over the time axis of each band.
        y_low = F.layer_norm(y_low, y_low.shape[-1:])
        y_high = F.layer_norm(y_high, y_high.shape[-1:])

        # Embed, then move to sequence-first-feature-last layout [B, T, d].
        L = self.proj_low(y_low).transpose(1, 2)
        H = self.proj_high(y_high).transpose(1, 2)

        Y, attn = self.cross(L, H, attn_mask=attn_mask, attn_bias=attn_bias)

        L_low, L_high, L_overlap, L_parseval, L_shape = self.wave.spectral_losses()
        reg = {
            "L_low": L_low, "L_high": L_high,
            "L_overlap": L_overlap, "L_parseval": L_parseval,
            "L_shape": L_shape
        }
        return Y, reg, attn
|
|
def dilate_kernel_1d(g, dilation):
    """Insert ``dilation - 1`` zeros between the taps along the last dimension.

    Args:
        g: Kernel tensor; the last dimension holds the taps.
        dilation: Spacing between taps in the result. ``1`` returns ``g`` itself.

    Returns:
        Tensor with the same leading dimensions, dtype and device as ``g`` and
        a last dimension of ``(L - 1) * dilation + 1``, where ``L`` is the
        original tap count.
    """
    if dilation == 1:
        return g
    taps = g.shape[-1]
    span = (taps - 1) * dilation + 1
    dilated = g.new_zeros((*g.shape[:-1], span))
    # Every dilation-th slot receives an original tap; the rest stay zero.
    dilated[..., ::dilation] = g
    return dilated
|
|
class MultiStageNeuralWavelet1D(nn.Module):
    """Cascade of learnable low-pass / high-pass FIR stages without downsampling.

    Every stage owns one learnable kernel per branch and may use its own
    dilation.

    Args:
        channels: Number of input channels ``C`` (also the conv group count).
        kernel_size: FIR length ``L`` of every stage.
        stages: Number of cascaded stages ``S``.
        dilations: One dilation per stage; defaults to 1 for every stage.
        share_across_channels: If true, each stage learns one kernel pair that
            is broadcast to all channels; otherwise one pair per channel.
        n_fft: FFT size for the spectral regularisers. It is raised
            automatically if a dilated kernel would not fit.
        band_power_p: Exponent of the frequency weighting in the band priors.
    """

    def __init__(self, channels: int, kernel_size: int = 31, stages: int = 3,
                 dilations=None, share_across_channels: bool = True,
                 n_fft: int = 512, band_power_p: float = 3.0):
        super().__init__()
        self.C = channels
        self.L = kernel_size
        self.S = stages
        self.share = share_across_channels
        self.n_fft = n_fft
        self.p = band_power_p

        # Copy so later mutation of the caller's sequence cannot affect us.
        dilations = [1] * stages if dilations is None else list(dilations)
        assert len(dilations) == stages, "need exactly one dilation per stage"
        self.dilations = dilations

        self.g0_list = nn.ParameterList()
        self.g1_list = nn.ParameterList()
        shape = (1, 1, self.L) if share_across_channels else (self.C, 1, self.L)

        for _ in range(stages):
            g0 = nn.Parameter(torch.randn(*shape) * 0.05)
            g1 = nn.Parameter(torch.randn(*shape) * 0.05)
            # Start each high-pass kernel with zero DC gain.
            with torch.no_grad():
                g1.sub_(g1.mean(dim=-1, keepdim=True))
            self.g0_list.append(g0)
            self.g1_list.append(g1)

    def _filters_per_stage(self):
        """Return per-stage kernel lists, broadcast to ``[C, 1, L]`` and scaled to unit L2 norm."""
        g0s, g1s = [], []
        for g0, g1 in zip(self.g0_list, self.g1_list):
            if self.share:
                g0 = g0.expand(self.C, 1, self.L)
                g1 = g1.expand(self.C, 1, self.L)
            g0s.append(g0 / (g0.norm(dim=(-2, -1), keepdim=True) + 1e-18))
            g1s.append(g1 / (g1.norm(dim=(-2, -1), keepdim=True) + 1e-18))
        return g0s, g1s

    def forward(self, x):
        """Apply the cascaded filters to ``x``.

        Args:
            x: Input of shape ``[B, C, T]``.

        Returns:
            Tuple ``(y_low, y_high)`` where
            ``y_low = g0_S * ... * g0_2 * g0_1 * x`` and
            ``y_high = g1_S * ... * g1_2 * g1_1 * x`` (``*`` is convolution).
        """
        g0s, g1s = self._filters_per_stage()
        y_low, y_high = x, x
        for g0, g1, d in zip(g0s, g1s, self.dilations):
            pad = (self.L - 1) // 2 * d
            y_low = F.conv1d(y_low, g0, padding=pad, dilation=d, groups=self.C)
            y_high = F.conv1d(y_high, g1, padding=pad, dilation=d, groups=self.C)
        return y_low, y_high

    @staticmethod
    def _stage_shape_loss(g0, g1):
        """Return the zero-mean (g1) plus symmetry (g0) penalty for one stage's ``[C, 1, L]`` kernels."""
        # Penalise every channel on its own: evaluating the channel-mean
        # kernel lets violations of opposite sign cancel across channels.
        zero_mean_high = g1.sum(dim=-1).pow(2).mean()
        sym_low = (g0 - torch.flip(g0, dims=[-1])).pow(2).mean()
        return zero_mean_high + sym_low

    def spectral_losses(self):
        """Compute the regularisation terms for the cascade.

        Returns:
            Tuple ``(L_low, L_high, L_overlap, L_parseval, L_shape)``:

            * ``L_low`` / ``L_high``: low-/high-pass band priors on the total
              (cascaded) magnitude response.
            * ``L_overlap``: spectral overlap of the two total responses.
            * ``L_parseval``: squared deviation of ``|G0|^2 + |G1|^2`` from 1.
            * ``L_shape``: per-stage ``g1`` zero-mean and ``g0`` symmetry
              penalties, averaged over stages.

        The spectral terms use the channel-averaged kernel of each stage.
        """
        g0s, g1s = self._filters_per_stage()

        # rfft trims inputs longer than n, so the FFT must cover the longest
        # dilated kernel; all stages share one size so responses multiply bin-wise.
        longest = max(((self.L - 1) * d + 1 for d in self.dilations), default=0)
        n_fft = max(self.n_fft, longest)

        G0_total = None
        G1_total = None
        L_shape_terms = []

        for g0, g1, d in zip(g0s, g1s, self.dilations):
            L_shape_terms.append(self._stage_shape_loss(g0, g1))

            # Effective (zero-stuffed) kernel actually applied by the dilated conv.
            g0_eff = dilate_kernel_1d(g0.mean(0).squeeze(0), d)
            g1_eff = dilate_kernel_1d(g1.mean(0).squeeze(0), d)
            _, A0 = rfft_mag(g0_eff, n_fft)
            _, A1 = rfft_mag(g1_eff, n_fft)

            # Cascading in time multiplies the magnitude responses.
            G0_total = A0 if G0_total is None else (G0_total * A0)
            G1_total = A1 if G1_total is None else (G1_total * A1)

        omega = omega_grid_like(G0_total)
        w_low = (omega / torch.pi) ** self.p
        w_high = ((torch.pi - omega) / torch.pi) ** self.p

        L_low = (w_low * (G0_total ** 2)).mean()
        L_high = (w_high * (G1_total ** 2)).mean()
        L_overlap = (G0_total ** 2 * G1_total ** 2).mean()

        L_parseval = ((G0_total ** 2 + G1_total ** 2 - 1.0) ** 2).mean()
        L_shape = torch.stack(L_shape_terms).mean()

        return L_low, L_high, L_overlap, L_parseval, L_shape
|
|