import torch
import torch.nn as nn
import torch.nn.functional as F


def rfft_mag(x, n_fft: int):
    """Return the one-sided FFT of ``x`` and its magnitude.

    Args:
        x: real tensor ``[..., L]``; the FFT is taken along the last axis.
        n_fft: FFT length. Following ``torch.fft.rfft``, ``x`` is zero-padded
            when shorter than ``n_fft`` and *truncated* when longer, so pass
            ``n_fft >= L`` to see the full response.

    Returns:
        ``(X, |X|)``, both of shape ``[..., n_fft // 2 + 1]``.
    """
    spec = torch.fft.rfft(x, n=n_fft)
    return spec, spec.abs()


def omega_grid_like(spec, n_fft=None):
    """Return angular frequencies (rad/sample) for the bins of a one-sided spectrum.

    Args:
        spec: tensor ``[..., K]``; only its last dimension and device are used.
        n_fft: FFT length that produced ``spec``. When given, the exact rfft
            bin frequencies ``2*pi*k / n_fft`` are returned (correct for both
            even and odd ``n_fft``). When omitted, ``K`` points evenly spaced
            on ``[0, pi]`` are returned, which matches the rfft bins only for
            an even FFT length.

    Returns:
        1-D tensor of length ``K`` on ``spec.device``.
    """
    num_bins = spec.shape[-1]
    if n_fft is None:
        return torch.linspace(0.0, torch.pi, steps=num_bins, device=spec.device)
    freqs = torch.fft.rfftfreq(n_fft, device=spec.device)  # k / n_fft
    return 2.0 * torch.pi * freqs


def _pad_same_1d(x, kernel_size: int, dilation: int = 1):
    """Zero-pad the last axis so a stride-1 convolution keeps its length.

    Total padding is ``dilation * (kernel_size - 1)``, split as evenly as
    possible (the extra sample goes on the right). For odd ``kernel_size``
    this is the usual symmetric ``dilation * (kernel_size - 1) // 2``.
    """
    total = dilation * (kernel_size - 1)
    left = total // 2
    return F.pad(x, (left, total - left))


class NeuralWavelet1D(nn.Module):
    """Learnable single-stage low/high-pass FIR filter pair (no downsampling).

    ``g0`` (low-pass) and ``g1`` (high-pass) are depthwise FIR kernels of
    length ``kernel_size``, either shared by all channels or one per channel.
    Their low/high-pass character is only encouraged by the data-independent
    frequency-domain penalties returned from ``spectral_losses``.
    """

    def __init__(self, channels: int, kernel_size: int = 81,
                 share_across_channels: bool = True, n_fft: int = 256,
                 band_power_p: float = 3.0):
        super().__init__()
        self.C = channels
        self.L = kernel_size
        self.share = share_across_channels
        self.n_fft = n_fft
        self.p = band_power_p

        shape = (1, 1, self.L) if share_across_channels else (self.C, 1, self.L)
        # Learnable FIR kernels (impulse responses), small random init.
        self.g0 = nn.Parameter(torch.randn(*shape) * 0.05)  # low-pass
        self.g1 = nn.Parameter(torch.randn(*shape) * 0.05)  # high-pass

        # Heuristic init: make the high-pass kernel zero-mean (zero DC gain).
        # A windowed-sinc init would give a closer low/high-pass starting point.
        with torch.no_grad():
            self.g1 -= self.g1.mean(dim=-1, keepdim=True)

    def _filters(self):
        """Return ``(g0, g1)`` as ``[C, 1, L]`` depthwise weights, each L2-normalised.

        Shared kernels are broadcast to every channel; the per-channel unit
        norm keeps the filter scale from drifting early in training.
        """
        if self.share:
            g0 = self.g0.expand(self.C, 1, self.L)
            g1 = self.g1.expand(self.C, 1, self.L)
        else:
            g0, g1 = self.g0, self.g1
        g0 = g0 / (g0.norm(dim=(-2, -1), keepdim=True) + 1e-18)
        g1 = g1 / (g1.norm(dim=(-2, -1), keepdim=True) + 1e-18)
        return g0, g1

    def forward(self, x):
        """Filter ``x`` with the low- and high-pass kernels.

        Args:
            x: ``[B, C, T]``.

        Returns:
            ``(y_low, y_high)``, both ``[B, C, T]`` (undecimated sub-bands).
            "Same" zero padding keeps length ``T`` for odd and even kernels.
        """
        g0, g1 = self._filters()
        x_padded = _pad_same_1d(x, self.L)
        y_low = F.conv1d(x_padded, g0, groups=self.C)
        y_high = F.conv1d(x_padded, g1, groups=self.C)
        return y_low, y_high

    # ---------- Frequency-domain regularisers (data-independent, on the filters) ----------
    def spectral_losses(self):
        """Compute penalties on the filters' frequency responses.

        Returns:
            L_low: high-frequency energy of g0 (band-shape prior for the low-pass).
            L_high: low-frequency energy of g1 (band-shape prior for the high-pass).
            L_overlap: mean of ``|G0|^2 * |G1|^2`` (penalises spectral overlap).
            L_parseval: mean of ``(|G0|^2 + |G1|^2 - 1)^2``.
            L_shape: ``(sum g1)^2`` (zero-mean high-pass) plus the asymmetry of g0.
        """
        g0, g1 = self._filters()  # [C, 1, L]
        # The channel-averaged kernel stands in for the filter's response.
        g0m = g0.mean(0).squeeze(0)  # [L]
        g1m = g1.mean(0).squeeze(0)

        # Never let rfft truncate the kernel: dropped taps would give a wrong response.
        n = max(self.n_fft, self.L)
        _, A0 = rfft_mag(g0m, n)  # [K]
        _, A1 = rfft_mag(g1m, n)
        omega = omega_grid_like(A0, n)  # [K], exact bin frequencies in [0, pi]

        # Monotone band weights: large at high frequencies for the low-pass,
        # large at low frequencies for the high-pass.
        w_low = (omega / torch.pi) ** self.p
        w_high = ((torch.pi - omega) / torch.pi) ** self.p

        L_low = (w_low * (A0 ** 2)).mean()
        L_high = (w_high * (A1 ** 2)).mean()
        L_overlap = (A0 ** 2 * A1 ** 2).mean()
        # NOTE(review): with shared kernels each filter has unit L2 norm
        # (see _filters), so by Parseval the full-circle average of
        # |G0|^2 + |G1|^2 is 2, and a target of 1 cannot be reached exactly.
        # Confirm the intended target.
        L_parseval = ((A0 ** 2 + A1 ** 2 - 1.0) ** 2).mean()

        # Shape constraints: zero-mean high-pass, symmetric (linear-phase) low-pass.
        L_zero_mean_high = g1.sum(dim=-1).mean().pow(2)
        L_sym_low = (g0 - torch.flip(g0, dims=[-1])).pow(2).mean()
        L_shape = L_zero_mean_high + L_sym_low

        return L_low, L_high, L_overlap, L_parseval, L_shape


class LowGuidedHighCrossAttn(nn.Module):
    """Multi-head cross-attention whose queries/keys come from the low band.

    Queries and keys are projected from the low-frequency stream ``L``, and
    values from the high-frequency stream ``H``. The attended high-frequency
    detail is added back onto ``L`` as a residual.
    """

    def __init__(self, d_model, nhead=4, dropout=0.0, use_gamma=True):
        super().__init__()
        if d_model % nhead != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by nhead ({nhead})"
            )
        self.nhead = nhead
        self.dk = d_model // nhead

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        # Zero-initialised output projection: the residual branch starts at exactly 0.
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

        self.dropout = nn.Dropout(dropout)
        self.use_gamma = use_gamma
        if use_gamma:
            # Residual gate: the branch is scaled by 0.01 * sigmoid(gamma),
            # i.e. 0.005 at init and always within (0, 0.01).
            self.gamma = nn.Parameter(torch.zeros(1))

    def _split_heads(self, x):
        """Reshape ``[B, T, D]`` into ``[B, h, T, dk]``."""
        B, T, D = x.shape
        return x.view(B, T, self.nhead, self.dk).transpose(1, 2)

    def _merge_heads(self, x):
        """Reshape ``[B, h, T, dk]`` back into ``[B, T, h*dk]``."""
        B, h, T, dk = x.shape
        return x.transpose(1, 2).contiguous().view(B, T, h * dk)

    def forward(self, L, H, attn_mask=None, attn_bias=None):
        """Attend from the low band to the high band.

        Args:
            L: low-frequency backbone ``[B, T, d]`` (source of queries and keys).
            H: high-frequency detail ``[B, T, d]`` (source of values).
            attn_mask: optional tensor *added* to the attention logits
                (an additive float mask broadcastable to ``[B, h, T, T]``).
            attn_bias: optional tensor added to the logits, same rules.

        Returns:
            ``(out, attn)``:
            - ``out`` is ``[B, T, d]``.
            - ``attn`` is ``[B, h, T, T]``, the softmax weights *before*
              dropout, so each row sums to 1 in both train and eval mode.
        """
        Q = self._split_heads(self.q_proj(L))
        K = self._split_heads(self.k_proj(L))  # Q and K both come from the low band
        V = self._split_heads(self.v_proj(H))  # values come from the high band

        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.dk ** 0.5)  # [B,h,T,T]
        if attn_bias is not None:
            scores = scores + attn_bias
        if attn_mask is not None:
            scores = scores + attn_mask

        attn = scores.softmax(dim=-1)  # [B,h,T,T]
        # Dropout only on the weights used to mix values, so the returned
        # attention stays a proper distribution for inspection/regularisation.
        Z = torch.matmul(self.dropout(attn), V)  # [B,h,T,dk]
        Z = self.out_proj(self._merge_heads(Z))  # [B,T,d]

        out = L + (0.01 * torch.sigmoid(self.gamma) * Z if self.use_gamma else Z)
        return out, attn


class WaveletFront(nn.Module):
    """Front end: learnable filter bank plus low-guided cross-attention fusion."""

    def __init__(self, in_channels=1, d_model=256, kernel_size=31, n_fft=256):
        super().__init__()
        # MultiStageNeuralWavelet1D (defined in this file) has the same
        # forward / spectral_losses outputs and can be swapped in here.
        self.wave = NeuralWavelet1D(in_channels, kernel_size, True, n_fft, band_power_p=3.0)
        self.proj_low = nn.Conv1d(in_channels, d_model, kernel_size=1)
        self.proj_high = nn.Conv1d(in_channels, d_model, kernel_size=1)
        self.cross = LowGuidedHighCrossAttn(d_model, nhead=4, dropout=0.1, use_gamma=True)

    def forward(self, x, attn_mask=None, attn_bias=None):
        """Split ``x`` into bands, project them, and fuse them with cross-attention.

        Args:
            x: ``[B, C, T]``.
            attn_mask, attn_bias: forwarded to ``LowGuidedHighCrossAttn``.

        Returns:
            y: ``[B, T, d_model]``; positional/time encodings can be added
                before feeding it to a Transformer.
            reg: dict of the filter bank's spectral regularisation losses.
            attn: cross-attention weights ``[B, h, T, T]``.
        """
        y_low, y_high = self.wave(x)  # [B,C,T], [B,C,T]

        # Normalise each channel's signal over the time axis, so neither band
        # dominates by scale.
        y_low = F.layer_norm(y_low, y_low.shape[-1:])
        y_high = F.layer_norm(y_high, y_high.shape[-1:])

        # Project to d_model and move time to axis 1.
        L = self.proj_low(y_low).transpose(1, 2)  # [B,T,d]
        H = self.proj_high(y_high).transpose(1, 2)  # [B,T,d]

        # Low -> high cross-attention; the output lives in the low-band stream.
        Y, attn = self.cross(L, H, attn_mask=attn_mask, attn_bias=attn_bias)

        # Spectral regularisers on the filters.
        L_low, L_high, L_overlap, L_parseval, L_shape = self.wave.spectral_losses()
        reg = {
            "L_low": L_low,
            "L_high": L_high,
            "L_overlap": L_overlap,
            "L_parseval": L_parseval,
            "L_shape": L_shape,
        }
        return Y, reg, attn


def dilate_kernel_1d(g, dilation):
    """Insert ``dilation - 1`` zeros between consecutive taps of ``g``.

    Args:
        g: kernel ``[..., L]``.
        dilation: positive int; 1 returns ``g`` unchanged.

    Returns:
        ``[..., (L - 1) * dilation + 1]``, the equivalent dense impulse response.
    """
    if dilation == 1:
        return g
    length = g.shape[-1]
    out = g.new_zeros(*g.shape[:-1], (length - 1) * dilation + 1)
    out[..., ::dilation] = g
    return out


class MultiStageNeuralWavelet1D(nn.Module):
    """Cascade of learnable low/high-pass FIR stages, without downsampling.

    Stage ``s`` owns a kernel pair ``(g0_s, g1_s)`` applied with dilation
    ``dilations[s]``. The low branch cascades only the g0 kernels and the high
    branch only the g1 kernels.
    """

    def __init__(self, channels: int, kernel_size: int = 31, stages: int = 3,
                 dilations=None, share_across_channels: bool = True,
                 n_fft: int = 512, band_power_p: float = 3.0):
        super().__init__()
        self.C = channels
        self.L = kernel_size
        self.S = stages
        self.share = share_across_channels
        self.n_fft = n_fft
        self.p = band_power_p

        if dilations is None:
            dilations = [1] * stages
        if len(dilations) != stages:
            raise ValueError(
                f"expected {stages} dilations, got {len(dilations)}"
            )
        self.dilations = dilations

        # One kernel pair per stage: g0_s (low-pass), g1_s (high-pass).
        self.g0_list = nn.ParameterList()
        self.g1_list = nn.ParameterList()
        shape = (1, 1, self.L) if share_across_channels else (self.C, 1, self.L)
        for _ in range(stages):
            g0 = nn.Parameter(torch.randn(*shape) * 0.05)
            g1 = nn.Parameter(torch.randn(*shape) * 0.05)
            # Heuristic init: zero-mean high-pass kernel.
            with torch.no_grad():
                g1 -= g1.mean(dim=-1, keepdim=True)
            self.g0_list.append(g0)
            self.g1_list.append(g1)

    def _filters_per_stage(self):
        """Return lists of per-stage ``[C, 1, L]`` kernels, each L2-normalised."""
        g0s, g1s = [], []
        for g0, g1 in zip(self.g0_list, self.g1_list):
            if self.share:
                g0 = g0.expand(self.C, 1, self.L)
                g1 = g1.expand(self.C, 1, self.L)
            # Light normalisation to stabilise training.
            g0s.append(g0 / (g0.norm(dim=(-2, -1), keepdim=True) + 1e-18))
            g1s.append(g1 / (g1.norm(dim=(-2, -1), keepdim=True) + 1e-18))
        return g0s, g1s

    def forward(self, x):
        """Apply the cascades.

        The low branch computes ``y_low = g0_S * ... * g0_1 * x`` and the high
        branch computes ``y_high = g1_S * ... * g1_1 * x``.

        Args:
            x: ``[B, C, T]``.

        Returns:
            ``(y_low, y_high)``, both ``[B, C, T]``. "Same" zero padding keeps
            length ``T`` for any kernel size and dilation.
        """
        g0s, g1s = self._filters_per_stage()
        y_low, y_high = x, x
        for g0, g1, d in zip(g0s, g1s, self.dilations):
            y_low = F.conv1d(_pad_same_1d(y_low, self.L, d), g0, dilation=d, groups=self.C)
            y_high = F.conv1d(_pad_same_1d(y_high, self.L, d), g1, dilation=d, groups=self.C)
        return y_low, y_high

    # -------- Frequency-domain regularisers on the *total* response (product over stages) --------
    def spectral_losses(self):
        """Compute penalties on the cascade's total frequency response.

        Returns:
            L_low, L_high: band-shape priors on the total low/high responses.
            L_overlap: mean of ``|G0_total|^2 * |G1_total|^2``.
            L_parseval: mean of ``(|G0_total|^2 + |G1_total|^2 - 1)^2``.
            L_shape: mean over stages of the zero-mean-high-pass and
                symmetric-low-pass penalties.
        """
        g0s, g1s = self._filters_per_stage()

        # One FFT length for all stages, long enough that no dilated kernel is
        # truncated, so the per-stage magnitudes can be multiplied bin by bin.
        n = max([self.n_fft] + [(self.L - 1) * d + 1 for d in self.dilations])

        G0_total = None
        G1_total = None
        shape_terms = []
        for g0, g1, d in zip(g0s, g1s, self.dilations):
            # The channel-averaged kernel stands in for the stage's response.
            g0m = g0.mean(0).squeeze(0)  # [L]
            g1m = g1.mean(0).squeeze(0)  # [L]

            # Shape constraints on every stage: zero-mean high-pass, symmetric low-pass.
            zero_mean_high = g1m.sum().pow(2)
            sym_low = (g0m - torch.flip(g0m, dims=[-1])).pow(2).mean()
            shape_terms.append(zero_mean_high + sym_low)

            # Expand the dilation into the dense impulse response, then FFT.
            _, A0 = rfft_mag(dilate_kernel_1d(g0m, d), n)  # [K]
            _, A1 = rfft_mag(dilate_kernel_1d(g1m, d), n)

            # A cascade multiplies magnitude responses.
            G0_total = A0 if G0_total is None else G0_total * A0
            G1_total = A1 if G1_total is None else G1_total * A1

        omega = omega_grid_like(G0_total, n)  # exact bin frequencies in [0, pi]
        w_low = (omega / torch.pi) ** self.p
        w_high = ((torch.pi - omega) / torch.pi) ** self.p

        L_low = (w_low * (G0_total ** 2)).mean()  # low-pass: penalise high-frequency energy
        L_high = (w_high * (G1_total ** 2)).mean()  # high-pass: penalise low-frequency energy
        L_overlap = (G0_total ** 2 * G1_total ** 2).mean()
        L_parseval = ((G0_total ** 2 + G1_total ** 2 - 1.0) ** 2).mean()
        L_shape = torch.stack(shape_terms).mean()

        return L_low, L_high, L_overlap, L_parseval, L_shape