File size: 12,854 Bytes
093b0a5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 | import torch
import torch.nn as nn
import torch.nn.functional as F
def rfft_mag(x, n_fft: int):
    """Compute the one-sided FFT of a real signal along its last axis.

    Args:
        x: real tensor of shape [..., L].
        n_fft: FFT length forwarded to ``torch.fft.rfft`` (zero-pads or truncates x).

    Returns:
        (spectrum, magnitude): the complex rfft of shape [..., n_fft // 2 + 1]
        and its element-wise absolute value.
    """
    spectrum = torch.fft.rfft(x, n=n_fft)
    magnitude = torch.abs(spectrum)
    return spectrum, magnitude
def omega_grid_like(spec, n_fft=None):
    """Return the angular-frequency grid (radians/sample) for the last axis of ``spec``.

    Args:
        spec: one-sided (rfft) spectrum whose last axis has K bins. Only its
            length, device and dtype are used.
        n_fft: optional FFT length that produced ``spec``. When given, the exact
            bin frequencies ``2*pi*k/n_fft`` are returned, which is also correct
            for odd ``n_fft``. When omitted (original behaviour), K points evenly
            spaced over ``[0, pi]`` are returned, which is exact for even ``n_fft``.

    Returns:
        1-D real tensor of length K on ``spec.device``. Its dtype is the real
        floating dtype matching ``spec`` (e.g. float32 for float32 or complex64
        input). For non-floating input the default dtype is used.

    Raises:
        ValueError: if ``n_fft`` is given and ``K != n_fft // 2 + 1``.
    """
    K = spec.shape[-1]
    dtype = spec.real.dtype if spec.is_complex() else spec.dtype
    if not dtype.is_floating_point:
        # Integer/bool input: fall back to the default float dtype, as before.
        dtype = torch.get_default_dtype()
    if n_fft is None:
        return torch.linspace(0.0, torch.pi, steps=K, device=spec.device, dtype=dtype)
    if K != n_fft // 2 + 1:
        raise ValueError(
            f"spec has {K} bins, expected n_fft // 2 + 1 = {n_fft // 2 + 1}"
        )
    return torch.arange(K, device=spec.device, dtype=dtype) * (2.0 * torch.pi / n_fft)
class NeuralWavelet1D(nn.Module):
    """Single-level learnable low-/high-pass FIR filter pair, applied without downsampling.

    Each input channel is filtered depthwise (``groups=channels``) by a learnable
    low-pass kernel ``g0`` and a learnable high-pass kernel ``g1``. Both sub-bands
    keep the input length. ``spectral_losses`` returns data-independent
    regularisers computed from the filters' frequency responses.

    Args:
        channels: number of input channels C.
        kernel_size: FIR length L (odd or even).
        share_across_channels: if True, a single kernel pair is shared by all channels.
        n_fft: FFT length used to evaluate the frequency responses.
        band_power_p: exponent of the monotone band-shape weights.
    """

    def __init__(self, channels: int, kernel_size: int = 81, share_across_channels: bool = True,
                 n_fft: int = 256, band_power_p: float = 3.0):
        super().__init__()
        self.C = channels
        self.L = kernel_size
        self.share = share_across_channels
        self.n_fft = n_fft
        self.p = band_power_p
        shape = (1, 1, self.L) if share_across_channels else (self.C, 1, self.L)
        # Learnable FIR kernels (impulse responses).
        self.g0 = nn.Parameter(torch.randn(*shape) * 0.05)  # low-pass
        self.g1 = nn.Parameter(torch.randn(*shape) * 0.05)  # high-pass
        # Heuristic init: make g1 zero-mean (zero DC gain) so it starts closer to a
        # high-pass. g0 stays random; a windowed-sinc init would be a stronger choice.
        with torch.no_grad():
            self.g1 -= self.g1.mean(dim=-1, keepdim=True)

    def _filters(self):
        """Return per-channel kernels ``(g0, g1)`` of shape [C, 1, L], each scaled to unit L2 norm."""
        if self.share:
            g0 = self.g0.expand(self.C, 1, self.L)
            g1 = self.g1.expand(self.C, 1, self.L)
        else:
            g0, g1 = self.g0, self.g1
        # Per-channel unit-norm scaling keeps the filter energy fixed, which guards
        # against early divergence. The tiny epsilon avoids division by zero.
        g0 = g0 / (g0.norm(dim=(-2, -1), keepdim=True) + 1e-18)
        g1 = g1 / (g1.norm(dim=(-2, -1), keepdim=True) + 1e-18)
        return g0, g1

    def _same_padding(self):
        """Return (left, right) zero padding that preserves the sequence length for any L."""
        total = self.L - 1
        left = total // 2
        return left, total - left

    def forward(self, x):
        """Split ``x`` into low- and high-pass sub-bands without downsampling.

        Args:
            x: input of shape [B, C, T].

        Returns:
            (y_low, y_high), each of shape [B, C, T]. The zero padding is split as
            ``((L-1)//2, L//2)``, so the length is preserved for odd and even L.
        """
        g0, g1 = self._filters()
        # Explicit asymmetric padding: symmetric padding=L//2 would yield T+1 samples for even L.
        x_padded = F.pad(x, self._same_padding())
        y_low = F.conv1d(x_padded, g0, groups=self.C)
        y_high = F.conv1d(x_padded, g1, groups=self.C)
        return y_low, y_high

    # ---------- Spectral regularisers (data-independent, act on filter responses) ----------
    def spectral_losses(self):
        """Compute regularisers on the filter pair.

        Returns:
            L_low: band-shape prior for g0 (penalises high-frequency energy).
            L_high: band-shape prior for g1 (penalises low-frequency energy).
            L_overlap: spectral-overlap penalty, mean of |G0|^2 * |G1|^2.
            L_parseval: energy-preservation penalty pushing |G0|^2 + |G1|^2 towards 1.
            L_shape: per-channel shape constraints (g1 zero DC gain + g0 symmetry).
        """
        g0, g1 = self._filters()  # [C, 1, L]
        # The channel-averaged kernel stands in for the frequency response.
        g0m = g0.mean(0).squeeze(0)  # [L]
        g1m = g1.mean(0).squeeze(0)
        # rfft truncates inputs longer than n, so make n cover the kernel. Keep n even
        # so the last rfft bin sits at pi, matching the [0, pi] grid of omega_grid_like.
        n = max(self.n_fft, self.L)
        n += n % 2
        _, A0 = rfft_mag(g0m, n)  # [K]
        _, A1 = rfft_mag(g1m, n)
        omega = omega_grid_like(A0)  # [K], spans [0, pi]
        # Monotone band-shape weights: large at high frequency for the low-pass,
        # large at low frequency for the high-pass.
        w_low = (omega / torch.pi) ** self.p
        w_high = ((torch.pi - omega) / torch.pi) ** self.p
        L_low = (w_low * (A0 ** 2)).mean()
        L_high = (w_high * (A1 ** 2)).mean()
        L_overlap = (A0 ** 2 * A1 ** 2).mean()
        # NOTE(review): _filters() gives each channel unit L2 norm. In shared mode
        # g0m/g1m are therefore unit-norm and, by Parseval, mean(A**2) over the grid
        # is about 1 for each filter. The sum then averages about 2, so this target
        # of 1 cannot be met exactly; confirm the intended target.
        L_parseval = ((A0 ** 2 + A1 ** 2 - 1.0) ** 2).mean()
        # Shape constraints, evaluated per channel so channels cannot cancel each other:
        # g1 should have zero DC gain (sum of taps), and g0 should be symmetric (linear phase).
        L_zero_mean_high = g1.sum(dim=-1).pow(2).mean()
        L_sym_low = (g0 - torch.flip(g0, dims=[-1])).pow(2).mean()
        L_shape = L_zero_mean_high + L_sym_low
        return L_low, L_high, L_overlap, L_parseval, L_shape
class LowGuidedHighCrossAttn(nn.Module):
    """Multi-head cross-attention in which the low-frequency stream steers the high-frequency one.

    Queries and keys are both projected from the low-frequency sequence ``L``.
    Values are projected from the high-frequency sequence ``H``. The attended
    detail goes through ``out_proj`` and is added back onto ``L`` as a residual.

    Args:
        d_model: feature width; must be divisible by ``nhead``.
        nhead: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        use_gamma: if True, the residual branch is scaled by ``0.01 * sigmoid(gamma)``.
            That scale lies in (0, 0.01); with gamma initialised to 0 it starts at 0.005.
    """

    def __init__(self, d_model, nhead=4, dropout=0.0, use_gamma=True):
        super().__init__()
        self.nhead = nhead
        self.dk = d_model // nhead
        assert d_model % nhead == 0
        # Created and registered in q, k, v, out order; keep this order stable
        # because it determines random initialisation and state_dict layout.
        self.q_proj, self.k_proj, self.v_proj, self.out_proj = (
            nn.Linear(d_model, d_model) for _ in range(4)
        )
        # A zero-initialised output projection makes the block return L unchanged at start.
        for tensor in (self.out_proj.weight, self.out_proj.bias):
            nn.init.zeros_(tensor)
        self.dropout = nn.Dropout(dropout)
        self.use_gamma = use_gamma
        if use_gamma:
            self.gamma = nn.Parameter(torch.zeros(1))

    def _split_heads(self, x):
        """Reshape [B, T, D] into [B, nhead, T, dk]."""
        batch, steps, _ = x.shape
        return x.view(batch, steps, self.nhead, self.dk).transpose(1, 2)

    def _merge_heads(self, x):
        """Reshape [B, nhead, T, dk] back into [B, T, nhead * dk]."""
        batch, heads, steps, width = x.shape
        return x.transpose(1, 2).contiguous().view(batch, steps, heads * width)

    def forward(self, L, H, attn_mask=None, attn_bias=None):
        """Fuse high-frequency detail into the low-frequency stream.

        Args:
            L: low-frequency features [B, T, d].
            H: high-frequency features [B, T, d].
            attn_mask, attn_bias: optional tensors broadcastable to [B, h, T, T].
                Both are ADDED to the scaled scores (bias first, then mask). Use
                large negative values or -inf at disallowed positions. A boolean
                mask would be added as 0/1 and would not block anything. A row
                that is entirely -inf produces NaN weights.

        Returns:
            (out, attn): ``out`` is [B, T, d]; ``attn`` holds the post-dropout
            attention weights [B, h, T, T].
        """
        low_q = self._split_heads(self.q_proj(L))
        low_k = self._split_heads(self.k_proj(L))
        high_v = self._split_heads(self.v_proj(H))
        scores = low_q @ low_k.transpose(-2, -1) / (self.dk ** 0.5)
        for additive in (attn_bias, attn_mask):
            if additive is not None:
                scores = scores + additive
        weights = self.dropout(torch.softmax(scores, dim=-1))
        fused = self.out_proj(self._merge_heads(weights @ high_v))
        if self.use_gamma:
            fused = 0.01 * torch.sigmoid(self.gamma) * fused
        return L + fused, weights
class WaveletFront(nn.Module):
    """Front end: learnable sub-band split followed by low-guided cross-attention fusion.

    Args:
        in_channels: number of input channels C.
        d_model: output feature width.
        kernel_size: FIR length of the learnable filter pair.
        n_fft: FFT length used by the filter bank's spectral regularisers.
        nhead: attention heads in the fusion block (previously hard-coded to 4).
        dropout: attention-weight dropout (previously hard-coded to 0.1).
        band_power_p: band-shape exponent for the filter bank (previously hard-coded to 3.0).
    """

    def __init__(self, in_channels=1, d_model=256, kernel_size=31, n_fft=256,
                 nhead=4, dropout=0.1, band_power_p=3.0):
        super().__init__()
        self.wave = NeuralWavelet1D(in_channels, kernel_size, True, n_fft,
                                    band_power_p=band_power_p)
        self.proj_low = nn.Conv1d(in_channels, d_model, kernel_size=1)
        self.proj_high = nn.Conv1d(in_channels, d_model, kernel_size=1)
        self.cross = LowGuidedHighCrossAttn(d_model, nhead=nhead, dropout=dropout,
                                            use_gamma=True)

    def forward(self, x, attn_mask=None, attn_bias=None):
        """Embed ``x`` and collect the filter-bank regularisers.

        Args:
            x: input of shape [B, C, T].
            attn_mask, attn_bias: forwarded unchanged to the cross-attention block.

        Returns:
            Y: fused features [B, T, d_model], ready for positional/time encodings
                and a Transformer.
            reg: dict with keys "L_low", "L_high", "L_overlap", "L_parseval", "L_shape"
                holding the filter bank's spectral regularisers.
            attn: cross-attention weights, useful for interpretation or regularisation.
        """
        y_low, y_high = self.wave(x)
        # layer_norm over the last (time) axis normalises each sample/channel sub-band
        # independently; the original intent is to stop the high band from dominating.
        y_low = F.layer_norm(y_low, y_low.shape[-1:])
        y_high = F.layer_norm(y_high, y_high.shape[-1:])
        # Project to d_model and move to [B, T, d].
        L = self.proj_low(y_low).transpose(1, 2)
        H = self.proj_high(y_high).transpose(1, 2)
        # Low -> high cross-attention; the output has the low branch's shape.
        Y, attn = self.cross(L, H, attn_mask=attn_mask, attn_bias=attn_bias)
        # Spectral regularisers act on the filters, not on the data.
        L_low, L_high, L_overlap, L_parseval, L_shape = self.wave.spectral_losses()
        reg = {
            "L_low": L_low, "L_high": L_high,
            "L_overlap": L_overlap, "L_parseval": L_parseval,
            "L_shape": L_shape,
        }
        return Y, reg, attn
def dilate_kernel_1d(g, dilation):
    """Insert ``dilation - 1`` zeros between consecutive taps of ``g``.

    Args:
        g: kernel tensor of shape [..., L].
        dilation: spacing between the original taps.

    Returns:
        ``g`` itself when ``dilation == 1``. Otherwise a new zero-filled tensor of
        shape [..., (L - 1) * dilation + 1] with the same dtype and device, holding
        ``g`` at every ``dilation``-th position.
    """
    if dilation == 1:
        return g
    taps = g.shape[-1]
    dilated = g.new_zeros((*g.shape[:-1], (taps - 1) * dilation + 1))
    dilated[..., ::dilation] = g
    return dilated
class MultiStageNeuralWavelet1D(nn.Module):
    """Cascade of learnable low-/high-pass FIR stages, applied without downsampling.

    Each stage owns one learnable low-pass kernel and one high-pass kernel, and
    each stage may use its own dilation. The low branch applies only low-pass
    kernels and the high branch only high-pass kernels.

    Args:
        channels: number of input channels C.
        kernel_size: FIR length L of every stage (odd or even).
        stages: number of cascaded stages S.
        dilations: per-stage dilation factors (length S); defaults to all ones.
            A private copy is stored.
        share_across_channels: if True, each stage's kernels are shared by all channels.
        n_fft: minimum FFT length for the spectral regularisers.
        band_power_p: exponent of the monotone band-shape weights.
    """

    def __init__(self, channels: int, kernel_size: int = 31, stages: int = 3,
                 dilations=None, share_across_channels: bool = True,
                 n_fft: int = 512, band_power_p: float = 3.0):
        super().__init__()
        self.C = channels
        self.L = kernel_size
        self.S = stages
        self.share = share_across_channels
        self.n_fft = n_fft
        self.p = band_power_p
        if dilations is None:
            dilations = [1] * stages
        # Copy so later mutations of the caller's list cannot silently change the module.
        dilations = list(dilations)
        assert len(dilations) == stages
        self.dilations = dilations
        # One kernel pair per stage: g0_s (low-pass), g1_s (high-pass).
        self.g0_list = nn.ParameterList()
        self.g1_list = nn.ParameterList()
        shape = (1, 1, self.L) if share_across_channels else (self.C, 1, self.L)
        for _ in range(stages):
            g0 = nn.Parameter(torch.randn(*shape) * 0.05)
            g1 = nn.Parameter(torch.randn(*shape) * 0.05)
            # Heuristic init: zero-mean (zero DC gain) high-pass kernel.
            with torch.no_grad():
                g1 -= g1.mean(dim=-1, keepdim=True)
            self.g0_list.append(g0)
            self.g1_list.append(g1)

    def _filters_per_stage(self):
        """Return per-stage lists of kernels [C, 1, L], broadcast over channels and scaled to unit L2 norm."""
        g0s, g1s = [], []
        for s in range(self.S):
            g0 = self.g0_list[s]
            g1 = self.g1_list[s]
            if self.share:
                g0 = g0.expand(self.C, 1, self.L)
                g1 = g1.expand(self.C, 1, self.L)
            # Per-channel unit-norm scaling stabilises training; epsilon avoids division by zero.
            g0 = g0 / (g0.norm(dim=(-2, -1), keepdim=True) + 1e-18)
            g1 = g1 / (g1.norm(dim=(-2, -1), keepdim=True) + 1e-18)
            g0s.append(g0)
            g1s.append(g1)
        return g0s, g1s

    def _same_padding(self, dilation):
        """Return (left, right) zero padding that preserves the length for a dilated kernel of length L."""
        total = dilation * (self.L - 1)
        left = total // 2
        return left, total - left

    def forward(self, x):
        """Run both cascades.

        Args:
            x: input of shape [B, C, T].

        Returns:
            (y_low, y_high), each [B, C, T]:
                y_low  = g0_S * ... * g0_1 * x
                y_high = g1_S * ... * g1_1 * x
        """
        g0s, g1s = self._filters_per_stage()
        y_low, y_high = x, x
        for g0, g1, d in zip(g0s, g1s, self.dilations):
            # Asymmetric padding keeps T fixed for even L as well; symmetric
            # (L-1)//2*d padding would shrink the sequence by d samples per stage.
            pad = self._same_padding(d)
            y_low = F.conv1d(F.pad(y_low, pad), g0, dilation=d, groups=self.C)
            y_high = F.conv1d(F.pad(y_high, pad), g1, dilation=d, groups=self.C)
        return y_low, y_high

    # -------- Spectral regularisers on the TOTAL response (product of stage responses) --------
    def spectral_losses(self):
        """Compute regularisers on the cascaded filter responses.

        Returns:
            L_low, L_high: band-shape priors on the total low-/high-pass magnitude.
            L_overlap: mean of |G0_total|^2 * |G1_total|^2.
            L_parseval: pushes |G0_total|^2 + |G1_total|^2 towards 1.
            L_shape: mean over stages of the per-channel shape constraints
                (g1 zero DC gain + g0 symmetry).
        """
        g0s, g1s = self._filters_per_stage()
        # rfft truncates inputs longer than n, so n must cover the longest dilated
        # kernel. Keep n even so the last bin sits at pi, as omega_grid_like assumes.
        n = max([self.n_fft] + [(self.L - 1) * d + 1 for d in self.dilations])
        n += n % 2
        A0_total = None  # running product of stage magnitude responses (low branch)
        A1_total = None  # same for the high branch
        shape_terms = []
        for g0, g1, d in zip(g0s, g1s, self.dilations):
            # Shape constraints per stage and per channel, so channels cannot cancel.
            L_zero_mean_high = g1.sum(dim=-1).pow(2).mean()
            L_sym_low = (g0 - torch.flip(g0, dims=[-1])).pow(2).mean()
            shape_terms.append(L_zero_mean_high + L_sym_low)
            # The channel-averaged kernel, expanded to its true dilated impulse
            # response, represents the stage's frequency response.
            g0m = g0.mean(0).squeeze(0)  # [L]
            g1m = g1.mean(0).squeeze(0)
            _, A0 = rfft_mag(dilate_kernel_1d(g0m, d), n)  # [K]
            _, A1 = rfft_mag(dilate_kernel_1d(g1m, d), n)
            A0_total = A0 if A0_total is None else A0_total * A0
            A1_total = A1 if A1_total is None else A1_total * A1
        omega = omega_grid_like(A0_total)  # [K], spans [0, pi]
        w_low = (omega / torch.pi) ** self.p
        w_high = ((torch.pi - omega) / torch.pi) ** self.p
        L_low = (w_low * (A0_total ** 2)).mean()  # low-pass: penalise high-frequency energy
        L_high = (w_high * (A1_total ** 2)).mean()  # high-pass: penalise low-frequency energy
        L_overlap = (A0_total ** 2 * A1_total ** 2).mean()
        L_parseval = ((A0_total ** 2 + A1_total ** 2 - 1.0) ** 2).mean()
        L_shape = torch.stack(shape_terms).mean()
        return L_low, L_high, L_overlap, L_parseval, L_shape
|