# WaveLSFromer / wavelet.py
# Initial WaveLSFromer project (author: ducheng678, commit 093b0a5).
import torch
import torch.nn as nn
import torch.nn.functional as F
def rfft_mag(x, n_fft: int):
    """Return the one-sided FFT of a real signal together with its magnitude.

    Args:
        x: Real tensor of shape [..., L]; the transform runs over the last axis.
        n_fft: FFT length forwarded to ``torch.fft.rfft`` as ``n`` (that
            function zero-pads or trims the last axis to this length).

    Returns:
        Tuple ``(spectrum, magnitude)``: the complex one-sided spectrum of
        shape [..., K] and its element-wise absolute value.
    """
    spectrum = torch.fft.rfft(x, n=n_fft)  # [..., K]
    magnitude = spectrum.abs()
    return spectrum, magnitude
def omega_grid_like(spec):
    """Build a uniform angular-frequency grid matching the last axis of ``spec``.

    Args:
        spec: Tensor of shape [..., K]; only its last-axis size and its device
            are used.

    Returns:
        1-D tensor with K evenly spaced values from 0 to pi (both included),
        created on ``spec.device`` with the default floating dtype.
    """
    num_bins = spec.size(-1)
    return torch.linspace(0.0, torch.pi, steps=num_bins, device=spec.device)
class NeuralWavelet1D(nn.Module):
    """Learnable pair of depthwise FIR filters producing two same-rate sub-bands.

    ``g0`` (intended low-pass) and ``g1`` (intended high-pass) are randomly
    initialised; nothing in the forward pass forces them to be low/high-pass.
    That behaviour is only encouraged by the penalties returned from
    :meth:`spectral_losses`, which the caller has to add to the training loss.

    Args:
        channels: Number of input channels C (filters are applied depthwise).
        kernel_size: Number of FIR taps L.
        share_across_channels: If True a single kernel pair is stored and
            broadcast to every channel; otherwise each channel owns its pair.
        n_fft: FFT length used to sample the filter frequency responses.
        band_power_p: Exponent of the monotone frequency weights used by the
            band-shape penalties.
    """

    def __init__(self, channels: int, kernel_size: int = 81, share_across_channels: bool = True,
                 n_fft: int = 256, band_power_p: float = 3.0):
        super().__init__()
        self.C = channels
        self.L = kernel_size
        self.share = share_across_channels
        self.n_fft = n_fft
        self.p = band_power_p
        shape = (1, 1, self.L) if share_across_channels else (self.C, 1, self.L)
        # Learnable FIR kernels (i.e. impulse responses).
        self.g0 = nn.Parameter(torch.randn(*shape) * 0.05)  # low-pass branch
        self.g1 = nn.Parameter(torch.randn(*shape) * 0.05)  # high-pass branch
        # Heuristic initialisation: a zero-mean kernel has zero DC gain.
        # (A windowed-sinc design could be used instead.)
        with torch.no_grad():
            self.g1 -= self.g1.mean(dim=-1, keepdim=True)

    def _filters(self):
        """Return ``(g0, g1)`` as [C, 1, L] kernels, each channel scaled to unit L2 norm."""
        if self.share:
            g0 = self.g0.expand(self.C, 1, self.L)
            g1 = self.g1.expand(self.C, 1, self.L)
        else:
            g0, g1 = self.g0, self.g1
        # Light norm normalisation to avoid divergence early in training.
        g0 = g0 / (g0.norm(dim=(-2, -1), keepdim=True) + 1e-18)
        g1 = g1 / (g1.norm(dim=(-2, -1), keepdim=True) + 1e-18)
        return g0, g1

    def forward(self, x):
        """Filter ``x`` with both kernels (depthwise, no downsampling).

        Args:
            x: Tensor of shape [B, C, T].

        Returns:
            ``(y_low, y_high)``, each [B, C, T'] with zero padding of ``L // 2``
            on both sides: T' == T for odd ``kernel_size`` and T' == T + 1 for
            even ``kernel_size``.
        """
        g0, g1 = self._filters()
        y_low = F.conv1d(x, g0, padding=self.L // 2, groups=self.C)
        y_high = F.conv1d(x, g1, padding=self.L // 2, groups=self.C)
        return y_low, y_high

    # ---------- Spectral regularisers (data independent, act on the filter responses) ----------
    def spectral_losses(self):
        """Compute the regularisation terms of the filter pair.

        The frequency responses are evaluated on the channel-averaged kernels.

        Returns:
            Tuple of scalar tensors
            ``(L_low, L_high, L_overlap, L_parseval, L_shape)``:

            * ``L_low``: band-shape prior on ``g0`` (penalises high-frequency energy).
            * ``L_high``: band-shape prior on ``g1`` (penalises low-frequency energy).
            * ``L_overlap``: mean of ``|G0|^2 * |G1|^2`` (discourages band overlap).
            * ``L_parseval``: mean of ``(|G0|^2 + |G1|^2 - 1)^2`` (energy preservation).
            * ``L_shape``: per-channel zero-DC penalty on ``g1`` plus symmetry
              (linear phase) penalty on ``g0``.
        """
        g0, g1 = self._filters()  # [C, 1, L]
        g0m = g0.mean(0).squeeze(0)  # [L] channel average stands in for the response
        g1m = g1.mean(0).squeeze(0)

        # torch.fft.rfft trims its input to ``n`` samples, so an FFT shorter than
        # the kernel would silently evaluate a truncated filter.  When n_fft is
        # too small, fall back to the smallest even length covering every tap
        # (even, so that the last one-sided bin sits exactly at pi).
        n_fft = self.n_fft if self.n_fft >= self.L else self.L + (self.L % 2)
        A0 = torch.fft.rfft(g0m, n=n_fft).abs()  # [K]
        A1 = torch.fft.rfft(g1m, n=n_fft).abs()  # [K]
        P0, P1 = A0.pow(2), A1.pow(2)
        omega = torch.linspace(0.0, torch.pi, steps=A0.shape[-1], device=A0.device)  # [K]

        # Monotone band-shape weights: the low-pass prior weighs high
        # frequencies, the high-pass prior weighs low frequencies.
        w_low = (omega / torch.pi) ** self.p
        w_high = ((torch.pi - omega) / torch.pi) ** self.p
        L_low = (w_low * P0).mean()
        L_high = (w_high * P1).mean()
        L_overlap = (P0 * P1).mean()
        L_parseval = ((P0 + P1 - 1.0) ** 2).mean()

        # Shape constraints.  The DC gain is squared per channel before
        # averaging; averaging first would let channels with opposite-sign DC
        # gain cancel and hide the violation when kernels are not shared.
        L_zero_mean_high = g1.sum(dim=-1).pow(2).mean()
        L_sym_low = (g0 - torch.flip(g0, dims=[-1])).pow(2).mean()
        L_shape = L_zero_mean_high + L_sym_low
        return L_low, L_high, L_overlap, L_parseval, L_shape
class LowGuidedHighCrossAttn(nn.Module):
    """Multi-head attention whose pattern comes from L and whose values come from H.

    Queries and keys are both projected from the low-frequency stream ``L``;
    values are projected from the high-frequency stream ``H``.  The attended
    result is added to ``L`` as a residual.

    Args:
        d_model: Feature width of both streams; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        use_gamma: If True the residual branch is scaled by
            ``0.01 * sigmoid(gamma)`` with a learnable scalar ``gamma``.
    """

    def __init__(self, d_model, nhead=4, dropout=0.0, use_gamma=True):
        super().__init__()
        self.nhead = nhead
        self.dk = d_model // nhead
        assert d_model % nhead == 0
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        # Zero-initialised output projection: the residual branch contributes
        # nothing at the start of training.
        for tensor in (self.out_proj.weight, self.out_proj.bias):
            nn.init.zeros_(tensor)
        self.dropout = nn.Dropout(dropout)
        self.use_gamma = use_gamma
        if use_gamma:
            # gamma starts at 0, i.e. an initial gate of 0.01 * sigmoid(0) = 0.005.
            self.gamma = nn.Parameter(torch.zeros(1))

    def _split_heads(self, x):
        """Reshape [B, T, D] into per-head layout [B, h, T, dk]."""
        batch, steps, _ = x.shape
        per_head = x.view(batch, steps, self.nhead, self.dk)
        return per_head.transpose(1, 2)

    def _merge_heads(self, x):
        """Reshape per-head layout [B, h, T, dk] back into [B, T, D]."""
        batch, heads, steps, width = x.shape
        token_major = x.transpose(1, 2).contiguous()
        return token_major.view(batch, steps, heads * width)

    def forward(self, L, H, attn_mask=None, attn_bias=None):
        """Fuse high-frequency detail into the low-frequency backbone.

        Args:
            L: Low-frequency backbone, [B, T, d].
            H: High-frequency detail, [B, T, d].
            attn_mask: Optional tensor added to the attention logits.
            attn_bias: Optional tensor added to the attention logits (applied
                before ``attn_mask``).

        Returns:
            ``(out, attn)`` where ``out`` is [B, T, d] and ``attn`` is the
            [B, h, T, T] attention weights after dropout.
        """
        query = self._split_heads(self.q_proj(L))
        key = self._split_heads(self.k_proj(L))  # note: Q and K both come from L
        value = self._split_heads(self.v_proj(H))  # values come from H

        logits = torch.matmul(query, key.transpose(-2, -1)) / (self.dk ** 0.5)  # [B, h, T, T]
        for additive in (attn_bias, attn_mask):
            if additive is not None:
                logits = logits + additive

        attn = self.dropout(logits.softmax(dim=-1))  # [B, h, T, T]
        fused = torch.matmul(attn, value)  # [B, h, T, dk]
        fused = self.out_proj(self._merge_heads(fused))  # [B, T, d]
        if self.use_gamma:
            fused = 0.01 * torch.sigmoid(self.gamma) * fused
        return L + fused, attn
class WaveletFront(nn.Module):
    """Wavelet front-end: learnable sub-band split, projection and cross-attention fusion.

    Args:
        in_channels: Number of input channels C.
        d_model: Width of the output token features.
        kernel_size: Number of taps of the learnable wavelet filters.
        n_fft: FFT length used by the spectral regularisers.
        nhead: Number of heads of the cross-attention block; ``d_model`` must
            be divisible by it.
        dropout: Dropout probability on the cross-attention weights.
    """

    def __init__(self, in_channels=1, d_model=256, kernel_size=31, n_fft=256,
                 nhead=4, dropout=0.1):
        super().__init__()
        # Single-stage filter bank with kernels shared across channels.
        # MultiStageNeuralWavelet1D exposes the same forward()/spectral_losses()
        # interface and can be used here instead, e.g.
        # MultiStageNeuralWavelet1D(in_channels, kernel_size, 3, [1, 2, 4], False, n_fft, band_power_p=3.0).
        self.wave = NeuralWavelet1D(in_channels, kernel_size, True, n_fft, band_power_p=3.0)
        self.proj_low = nn.Conv1d(in_channels, d_model, kernel_size=1)
        self.proj_high = nn.Conv1d(in_channels, d_model, kernel_size=1)
        self.cross = LowGuidedHighCrossAttn(d_model, nhead=nhead, dropout=dropout, use_gamma=True)

    def forward(self, x, attn_mask=None, attn_bias=None):
        """Turn a raw multichannel series into fused token features.

        Args:
            x: Input of shape [B, C, T].
            attn_mask: Optional additive mask forwarded to the cross-attention.
            attn_bias: Optional additive bias forwarded to the cross-attention.

        Returns:
            ``(y, reg, attn)``:

            * ``y``: [B, T, d_model] features; positional/temporal encodings
              can be added before feeding a Transformer.
            * ``reg``: dict with the regularisation terms ``L_low``, ``L_high``,
              ``L_overlap``, ``L_parseval`` and ``L_shape``.
            * ``attn``: cross-attention weights [B, h, T, T] (useful for
              interpretation or further regularisation).
        """
        y_low, y_high = self.wave(x)  # [B, C, T], [B, C, T]

        # Normalise each channel over the time axis (no affine parameters) so
        # that neither band dominates by scale.
        y_low = F.layer_norm(y_low, y_low.shape[-1:])
        y_high = F.layer_norm(y_high, y_high.shape[-1:])

        # Project both bands to d_model and switch to token-major layout.
        L = self.proj_low(y_low).transpose(1, 2)  # [B, T, d]
        H = self.proj_high(y_high).transpose(1, 2)  # [B, T, d]

        # Low-guided fusion of the high band; output has the shape of L.
        Y, attn = self.cross(L, H, attn_mask=attn_mask, attn_bias=attn_bias)

        # Spectral regularisers (act on the filters, not on the data).
        L_low, L_high, L_overlap, L_parseval, L_shape = self.wave.spectral_losses()
        reg = {
            "L_low": L_low, "L_high": L_high,
            "L_overlap": L_overlap, "L_parseval": L_parseval,
            "L_shape": L_shape,
        }
        return Y, reg, attn
def dilate_kernel_1d(g, dilation):
    """Insert ``dilation - 1`` zeros between consecutive taps of a kernel.

    Args:
        g: Kernel tensor of shape [..., L].
        dilation: Spacing between taps in the result; 1 returns ``g`` itself.

    Returns:
        Tensor of shape [..., (L - 1) * dilation + 1] with the taps of ``g`` at
        every ``dilation``-th position and zeros elsewhere (same device and
        dtype as ``g``).
    """
    taps = g.shape[-1]
    if dilation == 1:
        return g
    span = (taps - 1) * dilation + 1
    dilated = torch.zeros((*g.shape[:-1], span), device=g.device, dtype=g.dtype)
    dilated[..., ::dilation] = g
    return dilated
class MultiStageNeuralWavelet1D(nn.Module):
    """Cascade of learnable low/high-pass FIR filters without downsampling.

    Every stage owns one kernel pair and may use its own dilation.  The kernels
    are randomly initialised; their band behaviour is only encouraged by the
    penalties returned from :meth:`spectral_losses`.

    Args:
        channels: Number of input channels C (filters are applied depthwise).
        kernel_size: Number of FIR taps L per stage.
        stages: Number of cascaded stages S.
        dilations: Per-stage dilation factors (length S); defaults to all ones.
        share_across_channels: If True each stage stores one kernel pair that is
            broadcast to every channel; otherwise each channel owns its pair.
        n_fft: FFT length used to sample the filter frequency responses.
        band_power_p: Exponent of the monotone frequency weights used by the
            band-shape penalties.
    """

    def __init__(self, channels: int, kernel_size: int = 31, stages: int = 3,
                 dilations=None, share_across_channels: bool = True,
                 n_fft: int = 512, band_power_p: float = 3.0):
        super().__init__()
        self.C = channels
        self.L = kernel_size
        self.S = stages
        self.share = share_across_channels
        self.n_fft = n_fft
        self.p = band_power_p
        if dilations is None:
            dilations = [1] * stages
        assert len(dilations) == stages, "dilations must contain one entry per stage"
        # Keep a private copy so later mutation of the caller's list cannot
        # change the behaviour of this module.
        self.dilations = list(dilations)
        # One kernel pair per stage: g0_s (low-pass branch), g1_s (high-pass branch).
        self.g0_list = nn.ParameterList()
        self.g1_list = nn.ParameterList()
        shape = (1, 1, self.L) if share_across_channels else (self.C, 1, self.L)
        for _ in range(stages):
            g0 = nn.Parameter(torch.randn(*shape) * 0.05)
            g1 = nn.Parameter(torch.randn(*shape) * 0.05)
            # Heuristic initialisation: a zero-mean kernel has zero DC gain.
            with torch.no_grad():
                g1 -= g1.mean(dim=-1, keepdim=True)
            self.g0_list.append(g0)
            self.g1_list.append(g1)

    def _filters_per_stage(self):
        """Return two lists of per-stage [C, 1, L] kernels, each channel at unit L2 norm."""
        g0s, g1s = [], []
        for g0, g1 in zip(self.g0_list, self.g1_list):
            if self.share:
                g0 = g0.expand(self.C, 1, self.L)
                g1 = g1.expand(self.C, 1, self.L)
            # Light normalisation to stabilise training.
            g0s.append(g0 / (g0.norm(dim=(-2, -1), keepdim=True) + 1e-18))
            g1s.append(g1 / (g1.norm(dim=(-2, -1), keepdim=True) + 1e-18))
        return g0s, g1s

    @staticmethod
    def _magnitude_response(kernel, dilation, n_fft):
        """Magnitude of the one-sided FFT of a 1-D kernel after zero insertion."""
        if dilation != 1:
            taps = kernel.shape[-1]
            dilated = torch.zeros((taps - 1) * dilation + 1,
                                  device=kernel.device, dtype=kernel.dtype)
            dilated[::dilation] = kernel
            kernel = dilated
        return torch.fft.rfft(kernel, n=n_fft).abs()

    def forward(self, x):
        """Apply the cascade to ``x``.

        Low branch:  y_low  = g0_S * ... * g0_2 * g0_1 * x
        High branch: y_high = g1_S * ... * g1_2 * g1_1 * x

        Args:
            x: Tensor of shape [B, C, T].

        Returns:
            ``(y_low, y_high)``.  Each stage pads ``(L - 1) // 2 * dilation``
            zeros on both sides, which keeps the length T for odd
            ``kernel_size``.
        """
        g0s, g1s = self._filters_per_stage()
        y_low, y_high = x, x
        for g0, g1, d in zip(g0s, g1s, self.dilations):
            pad = (self.L - 1) // 2 * d
            y_low = F.conv1d(y_low, g0, padding=pad, dilation=d, groups=self.C)
            y_high = F.conv1d(y_high, g1, padding=pad, dilation=d, groups=self.C)
        return y_low, y_high

    # -------- Spectral regularisers acting on the total response (product over stages) --------
    def spectral_losses(self):
        """Compute the regularisation terms of the cascade.

        The frequency responses are evaluated on channel-averaged kernels and
        multiplied across stages to obtain the total magnitude response.

        Returns:
            Tuple of scalar tensors
            ``(L_low, L_high, L_overlap, L_parseval, L_shape)``:

            * ``L_low`` / ``L_high``: band-shape priors on the total low/high response.
            * ``L_overlap``: mean of ``|G0|^2 * |G1|^2`` of the total responses.
            * ``L_parseval``: mean of ``(|G0|^2 + |G1|^2 - 1)^2``.
            * ``L_shape``: stage-averaged, per-channel zero-DC penalty on
              ``g1`` plus symmetry penalty on ``g0``.
        """
        g0s, g1s = self._filters_per_stage()

        # torch.fft.rfft trims its input to ``n`` samples.  A dilated kernel
        # spans (L - 1) * d + 1 samples, so make sure the FFT covers the longest
        # stage; when n_fft is too small use the smallest even length that does
        # (even, so that the last one-sided bin sits exactly at pi).
        longest = max([1] + [(self.L - 1) * d + 1 for d in self.dilations])
        n_fft = self.n_fft if self.n_fft >= longest else longest + (longest % 2)

        G0_total = None
        G1_total = None
        L_shape_terms = []
        for g0, g1, d in zip(g0s, g1s, self.dilations):
            # Shape constraints per stage, evaluated per channel so that
            # violations of opposite sign in different channels cannot cancel.
            L_zero_mean_high = g1.sum(dim=-1).pow(2).mean()
            L_sym_low = (g0 - torch.flip(g0, dims=[-1])).pow(2).mean()
            L_shape_terms.append(L_zero_mean_high + L_sym_low)

            # Channel-averaged kernel -> zero-inserted impulse response -> |FFT|.
            A0 = self._magnitude_response(g0.mean(0).squeeze(0), d, n_fft)  # [K]
            A1 = self._magnitude_response(g1.mean(0).squeeze(0), d, n_fft)  # [K]
            # Magnitudes of cascaded filters multiply.
            G0_total = A0 if G0_total is None else (G0_total * A0)
            G1_total = A1 if G1_total is None else (G1_total * A1)

        P0, P1 = G0_total.pow(2), G1_total.pow(2)
        omega = torch.linspace(0.0, torch.pi, steps=G0_total.shape[-1],
                               device=G0_total.device)  # [0, pi]
        w_low = (omega / torch.pi) ** self.p
        w_high = ((torch.pi - omega) / torch.pi) ** self.p
        L_low = (w_low * P0).mean()  # low branch: penalise high-frequency energy
        L_high = (w_high * P1).mean()  # high branch: penalise low-frequency energy
        L_overlap = (P0 * P1).mean()
        L_parseval = ((P0 + P1 - 1.0) ** 2).mean()
        L_shape = torch.stack(L_shape_terms).mean()
        return L_low, L_high, L_overlap, L_parseval, L_shape