File size: 12,854 Bytes
093b0a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
import torch
import torch.nn as nn
import torch.nn.functional as F


def rfft_mag(x, n_fft: int):
    """Return the one-sided FFT of ``x`` together with its magnitude.

    Args:
        x: Real-valued tensor of shape ``[..., L]``; the transform is taken
            over the last dimension.
        n_fft: FFT length, forwarded to ``torch.fft.rfft`` as ``n``.

    Returns:
        A tuple ``(spectrum, magnitude)`` where ``spectrum`` is the complex
        one-sided spectrum of shape ``[..., K]`` and ``magnitude`` is its
        element-wise absolute value.
    """
    spectrum = torch.fft.rfft(x, n=n_fft)
    magnitude = torch.abs(spectrum)
    return spectrum, magnitude

def omega_grid_like(spec):
    """Build a uniform angular-frequency grid spanning ``[0, pi]``.

    Args:
        spec: Tensor of shape ``[..., K]``; only the size of its last
            dimension and its device are used.

    Returns:
        A 1-D tensor with ``K`` evenly spaced points from ``0`` to ``pi``,
        placed on the same device as ``spec``.
    """
    num_bins = spec.size(-1)
    return torch.linspace(0.0, torch.pi, num_bins, device=spec.device)

class NeuralWavelet1D(nn.Module):
    """Learnable low-pass / high-pass FIR filter pair applied per channel.

    Both filters are applied as depthwise 1-D convolutions (``groups=C``)
    without downsampling, so each output sub-band has the same length as
    the input.

    Args:
        channels: Number of input channels ``C``.
        kernel_size: Number of taps ``L`` of each FIR kernel.
        share_across_channels: If True a single kernel pair is learned and
            broadcast to all channels; otherwise each channel has its own pair.
        n_fft: FFT length used when evaluating the filters' frequency
            responses in :meth:`spectral_losses`.
        band_power_p: Exponent of the frequency weights used by the
            band-shape penalties.
    """

    def __init__(self, channels: int, kernel_size: int = 81, share_across_channels: bool = True,
                 n_fft: int = 256, band_power_p: float = 3.0):
        super().__init__()
        self.C = channels
        self.L = kernel_size
        self.share = share_across_channels
        self.n_fft = n_fft
        self.p = band_power_p

        shape = (1, 1, self.L) if share_across_channels else (self.C, 1, self.L)
        # Learnable FIR kernels (i.e. impulse responses).
        self.g0 = nn.Parameter(torch.randn(*shape) * 0.05)  # low-pass
        self.g1 = nn.Parameter(torch.randn(*shape) * 0.05)  # high-pass

        # Heuristic initialisation: remove the mean of the high-pass kernel so
        # that it starts with zero DC gain.
        with torch.no_grad():
            self.g1 -= self.g1.mean(dim=-1, keepdim=True)

    def _filters(self):
        """Return the ``[C, 1, L]`` kernels, each scaled to unit L2 norm."""
        if self.share:
            g0 = self.g0.expand(self.C, 1, self.L)
            g1 = self.g1.expand(self.C, 1, self.L)
        else:
            g0, g1 = self.g0, self.g1
        # Per-channel norm normalisation keeps the filter gain bounded.
        g0 = g0 / (g0.norm(dim=(-2, -1), keepdim=True) + 1e-18)
        g1 = g1 / (g1.norm(dim=(-2, -1), keepdim=True) + 1e-18)
        return g0, g1

    def forward(self, x):
        """Split ``x`` into low-pass and high-pass sub-bands.

        Args:
            x: Input of shape ``[B, C, T]``.

        Returns:
            ``(y_low, y_high)``, both of shape ``[B, C, T]`` (no downsampling).
        """
        g0, g1 = self._filters()
        # padding="same" keeps the output length equal to T for every kernel
        # size. A fixed symmetric padding of L // 2 only does so for odd L and
        # yields T + 1 samples for even L.
        y_low = F.conv1d(x, g0, padding="same", groups=self.C)
        y_high = F.conv1d(x, g1, padding="same", groups=self.C)
        return y_low, y_high

    # ---------- Spectral regularisers (data independent, act on the filter responses) ----------
    def spectral_losses(self):
        """Compute regularisation terms from the filters' frequency responses.

        The frequency responses are evaluated on the channel-averaged kernels.

        Returns:
            A tuple ``(L_low, L_high, L_overlap, L_parseval, L_shape)``:

            * ``L_low``: high-frequency-weighted energy of the low-pass ``g0``.
            * ``L_high``: low-frequency-weighted energy of the high-pass ``g1``.
            * ``L_overlap``: mean of ``|G0|^2 * |G1|^2`` (band overlap).
            * ``L_parseval``: mean squared deviation of ``|G0|^2 + |G1|^2``
              from 1 (approximate energy conservation).
            * ``L_shape``: zero-sum penalty on ``g1`` plus asymmetry penalty
              on ``g0``.
        """
        g0, g1 = self._filters()                  # [C, 1, L]
        # Use the channel average as the representative impulse response.
        g0m = g0.mean(0).squeeze(0)               # [L]
        g1m = g1.mean(0).squeeze(0)

        # Only the magnitudes are needed below.
        _, A0 = rfft_mag(g0m, self.n_fft)         # [K]
        _, A1 = rfft_mag(g1m, self.n_fft)

        omega = omega_grid_like(A0)               # [K]
        # Monotone band-shape weights: the low-pass weight grows with
        # frequency, the high-pass weight decays with frequency.
        w_low = (omega / torch.pi) ** self.p
        w_high = ((torch.pi - omega) / torch.pi) ** self.p

        L_low = (w_low * (A0**2)).mean()          # penalise high-frequency energy of g0
        L_high = (w_high * (A1**2)).mean()        # penalise low-frequency energy of g1
        L_overlap = (A0**2 * A1**2).mean()        # point-wise product of squared magnitudes
        L_parseval = ((A0**2 + A1**2 - 1.0)**2).mean()

        # Shape constraints: g1 sums to zero (high-pass), g0 is approximately
        # symmetric (linear phase).
        L_zero_mean_high = g1.sum(dim=-1).mean().pow(2)
        L_sym_low = (g0 - torch.flip(g0, dims=[-1])).pow(2).mean()
        L_shape = L_zero_mean_high + L_sym_low

        return L_low, L_high, L_overlap, L_parseval, L_shape

class LowGuidedHighCrossAttn(nn.Module):
    """Multi-head attention whose queries and keys come from the low band
    and whose values come from the high band.

    The attended high-band values are added residually to the low-band
    input. The output projection is zero-initialised, and when ``use_gamma``
    is set the residual branch is additionally scaled by
    ``0.01 * sigmoid(gamma)`` with a learnable scalar ``gamma``.

    Args:
        d_model: Feature dimension; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        use_gamma: Whether to create and apply the learnable ``gamma`` gate.
    """

    def __init__(self, d_model, nhead=4, dropout=0.0, use_gamma=True):
        super().__init__()
        self.nhead = nhead
        self.dk = d_model // nhead
        assert d_model % nhead == 0

        # Projections are created in q, k, v, out order.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        # Zero output projection: the residual branch contributes nothing
        # at initialisation.
        for tensor in (self.out_proj.weight, self.out_proj.bias):
            nn.init.zeros_(tensor)

        self.dropout = nn.Dropout(dropout)
        self.use_gamma = use_gamma
        if use_gamma:
            # Learnable scalar gate for the residual branch, starts at 0.
            self.gamma = nn.Parameter(torch.zeros(1))

    def _split_heads(self, x):
        """Reshape ``[B, T, D]`` into ``[B, h, T, dk]``."""
        batch, steps, _ = x.shape
        per_head = x.view(batch, steps, self.nhead, self.dk)
        return per_head.permute(0, 2, 1, 3)

    def _merge_heads(self, x):
        """Reshape ``[B, h, T, dk]`` back into ``[B, T, h * dk]``."""
        batch, heads, steps, head_dim = x.shape
        stacked = x.permute(0, 2, 1, 3).contiguous()
        return stacked.view(batch, steps, heads * head_dim)

    def forward(self, L, H, attn_mask=None, attn_bias=None):
        """Fuse high-band detail into the low-band backbone.

        Args:
            L: Low-frequency backbone features, ``[B, T, d]``.
            H: High-frequency detail features, ``[B, T, d]``.
            attn_mask: Optional tensor added to the attention scores before
                the softmax.
            attn_bias: Optional tensor added to the attention scores before
                the softmax (applied before ``attn_mask``).

        Returns:
            ``(out, attn)`` with ``out`` of shape ``[B, T, d]`` and ``attn``
            (the attention weights after dropout) of shape ``[B, h, T, T]``.
        """
        # Queries and keys both come from the low band; values from the high band.
        queries = self._split_heads(self.q_proj(L))
        keys = self._split_heads(self.k_proj(L))
        values = self._split_heads(self.v_proj(H))

        scale = self.dk ** 0.5
        scores = (queries @ keys.transpose(-2, -1)) / scale   # [B, h, T, T]
        for additive in (attn_bias, attn_mask):
            if additive is not None:
                scores = scores + additive

        attn = self.dropout(torch.softmax(scores, dim=-1))    # [B, h, T, T]

        context = self._merge_heads(attn @ values)            # [B, T, d]
        residual = self.out_proj(context)                     # [B, T, d]
        if self.use_gamma:
            residual = 0.01 * torch.sigmoid(self.gamma) * residual

        return L + residual, attn

class WaveletFront(nn.Module):
    """Front-end that splits the input into learnable low/high sub-bands,
    projects both to ``d_model`` and fuses them with low-guided cross-attention.

    Args:
        in_channels: Number of input channels ``C``.
        d_model: Output feature dimension.
        kernel_size: Number of taps of the learnable wavelet filters.
        n_fft: FFT length used for the spectral regularisers.
    """

    def __init__(self, in_channels=1, d_model=256, kernel_size=31, n_fft=256):
        super().__init__()
        self.wave = NeuralWavelet1D(
            in_channels, kernel_size, True, n_fft, band_power_p=3.0
        )
        # Disabled alternative:
        #   MultiStageNeuralWavelet1D(in_channels, kernel_size, 3, [1, 2, 4], False, n_fft, band_power_p=3.0)
        self.proj_low = nn.Conv1d(in_channels, d_model, kernel_size=1)
        self.proj_high = nn.Conv1d(in_channels, d_model, kernel_size=1)
        self.cross = LowGuidedHighCrossAttn(
            d_model, nhead=4, dropout=0.1, use_gamma=True
        )

    def forward(self, x, attn_mask=None, attn_bias=None):
        """Run the wavelet split, projection and cross-attention fusion.

        Args:
            x: Input of shape ``[B, C, T]``.
            attn_mask: Optional tensor forwarded to the cross-attention.
            attn_bias: Optional tensor forwarded to the cross-attention.

        Returns:
            ``(y, reg, attn)`` where ``y`` is ``[B, T, d_model]`` (ready for
            positional encoding and a Transformer), ``reg`` is a dict of
            spectral regularisation losses, and ``attn`` holds the
            cross-attention weights.
        """
        low_band, high_band = self.wave(x)                      # [B, C, T] each

        # Normalise each sub-band over its last axis so that neither band
        # dominates by scale.
        low_band = F.layer_norm(low_band, low_band.shape[-1:])
        high_band = F.layer_norm(high_band, high_band.shape[-1:])

        # 1x1 convolutions lift the channels to d_model, then move time first.
        low_tokens = self.proj_low(low_band).transpose(1, 2)     # [B, T, d]
        high_tokens = self.proj_high(high_band).transpose(1, 2)  # [B, T, d]

        # Low-guided cross-attention; output has the low band's dimensions.
        fused, attn = self.cross(
            low_tokens, high_tokens, attn_mask=attn_mask, attn_bias=attn_bias
        )

        # Spectral regularisers, computed on the filters themselves.
        loss_names = ("L_low", "L_high", "L_overlap", "L_parseval", "L_shape")
        reg = dict(zip(loss_names, self.wave.spectral_losses()))
        return fused, reg, attn

def dilate_kernel_1d(g, dilation):
    """Zero-insert a kernel along its last axis to emulate dilation.

    Args:
        g: Kernel tensor of shape ``[..., L]``.
        dilation: Spacing between consecutive taps in the result.

    Returns:
        ``g`` itself when ``dilation == 1``; otherwise a new tensor of shape
        ``[..., (L - 1) * dilation + 1]`` with the original taps placed every
        ``dilation`` samples and zeros in between.
    """
    num_taps = g.shape[-1]
    if dilation == 1:
        return g
    dilated_len = (num_taps - 1) * dilation + 1
    # new_zeros inherits dtype and device from g.
    expanded = g.new_zeros((*g.shape[:-1], dilated_len))
    expanded[..., ::dilation] = g
    return expanded

class MultiStageNeuralWavelet1D(nn.Module):
    """Cascade of learnable low-pass / high-pass FIR filters, no downsampling.

    Each stage owns one learnable kernel pair and may use its own dilation.
    The low branch chains all low-pass kernels and the high branch chains all
    high-pass kernels, each applied per channel (``groups=C``).

    Args:
        channels: Number of input channels ``C``.
        kernel_size: Number of taps ``L`` of every FIR kernel.
        stages: Number of cascaded stages ``S``.
        dilations: Per-stage dilation factors (length ``stages``); defaults
            to 1 for every stage.
        share_across_channels: If True each stage learns a single kernel pair
            that is broadcast to all channels.
        n_fft: FFT length used in :meth:`spectral_losses`.
        band_power_p: Exponent of the frequency weights used by the
            band-shape penalties.
    """

    def __init__(self, channels: int, kernel_size: int = 31, stages: int = 3,
                 dilations=None, share_across_channels: bool = True,
                 n_fft: int = 512, band_power_p: float = 3.0):
        super().__init__()
        self.C = channels
        self.L = kernel_size
        self.S = stages
        self.share = share_across_channels
        self.n_fft = n_fft
        self.p = band_power_p

        if dilations is None:
            dilations = [1] * stages
        assert len(dilations) == stages, "dilations must have one entry per stage"
        # Copy so later mutation of the caller's sequence cannot affect the module.
        self.dilations = list(dilations)

        # One kernel pair per stage: g0_s (low-pass), g1_s (high-pass).
        self.g0_list = nn.ParameterList()
        self.g1_list = nn.ParameterList()
        shape = (1, 1, self.L) if share_across_channels else (self.C, 1, self.L)

        for _ in range(stages):
            g0 = nn.Parameter(torch.randn(*shape) * 0.05)
            g1 = nn.Parameter(torch.randn(*shape) * 0.05)
            # Heuristic initialisation: zero-mean high-pass kernel.
            with torch.no_grad():
                g1 -= g1.mean(dim=-1, keepdim=True)
            self.g0_list.append(g0)
            self.g1_list.append(g1)

    def _filters_per_stage(self):
        """Return per-stage lists of ``[C, 1, L]`` kernels with unit L2 norm."""
        g0s, g1s = [], []
        for g0, g1 in zip(self.g0_list, self.g1_list):
            if self.share:
                g0 = g0.expand(self.C, 1, self.L)
                g1 = g1.expand(self.C, 1, self.L)
            # Per-channel norm normalisation keeps the filter gain bounded.
            g0 = g0 / (g0.norm(dim=(-2, -1), keepdim=True) + 1e-18)
            g1 = g1 / (g1.norm(dim=(-2, -1), keepdim=True) + 1e-18)
            g0s.append(g0)
            g1s.append(g1)
        return g0s, g1s

    def forward(self, x):
        """Apply the cascaded filters to ``x``.

        Low branch:  ``y_low  = g0_S * ... * g0_2 * g0_1 * x``
        High branch: ``y_high = g1_S * ... * g1_2 * g1_1 * x``

        Args:
            x: Input of shape ``[B, C, T]``.

        Returns:
            ``(y_low, y_high)``, both of shape ``[B, C, T]``.
        """
        g0s, g1s = self._filters_per_stage()
        y_low, y_high = x, x
        for g0, g1, d in zip(g0s, g1s, self.dilations):
            # padding="same" preserves the length for any kernel size. A
            # symmetric padding of (L - 1) // 2 * d only does so for odd L and
            # loses d samples per stage for even L.
            y_low = F.conv1d(y_low, g0, padding="same", dilation=d, groups=self.C)
            y_high = F.conv1d(y_high, g1, padding="same", dilation=d, groups=self.C)
        return y_low, y_high

    # -------- Spectral regularisers acting on the overall response (product of stage responses) --------
    def spectral_losses(self):
        """Compute regularisation terms from the cascaded frequency responses.

        Per stage, the channel-averaged kernel is zero-inserted according to
        the stage dilation and transformed; the stage magnitudes are then
        multiplied to obtain the overall magnitude response of each branch.

        Returns:
            A tuple ``(L_low, L_high, L_overlap, L_parseval, L_shape)``:

            * ``L_low`` / ``L_high``: band-shape penalties on the overall
              low-pass / high-pass magnitude responses.
            * ``L_overlap``: mean of ``|G0|^2 * |G1|^2`` (band overlap).
            * ``L_parseval``: mean squared deviation of ``|G0|^2 + |G1|^2``
              from 1 (approximate energy conservation).
            * ``L_shape``: stage-averaged zero-sum penalty on ``g1`` plus
              asymmetry penalty on ``g0``.
        """
        g0s, g1s = self._filters_per_stage()

        G0_total = None
        G1_total = None
        L_shape_terms = []

        for g0, g1, d in zip(g0s, g1s, self.dilations):
            # Use the channel average as the representative kernel of the stage.
            g0m = g0.mean(0).squeeze(0)  # [L]
            g1m = g1.mean(0).squeeze(0)  # [L]

            # Shape constraints applied to every stage: the high-pass kernel
            # sums to zero, the low-pass kernel is symmetric.
            L_zero_mean_high = g1m.sum().pow(2)
            L_sym_low = (g0m - torch.flip(g0m, dims=[-1])).pow(2).mean()
            L_shape_terms.append(L_zero_mean_high + L_sym_low)

            # Expand the dilation into the equivalent impulse response before the FFT.
            g0_eff = dilate_kernel_1d(g0m, d)
            g1_eff = dilate_kernel_1d(g1m, d)
            _, A0 = rfft_mag(g0_eff, self.n_fft)   # [K]
            _, A1 = rfft_mag(g1_eff, self.n_fft)

            # Cascading stages multiplies their magnitude responses.
            G0_total = A0 if G0_total is None else (G0_total * A0)
            G1_total = A1 if G1_total is None else (G1_total * A1)

        omega = omega_grid_like(G0_total)          # grid on [0, pi]
        w_low = (omega / torch.pi) ** self.p
        w_high = ((torch.pi - omega) / torch.pi) ** self.p

        L_low = (w_low * (G0_total**2)).mean()     # penalise high-frequency energy of the low branch
        L_high = (w_high * (G1_total**2)).mean()   # penalise low-frequency energy of the high branch
        L_overlap = (G0_total**2 * G1_total**2).mean()
        L_parseval = ((G0_total**2 + G1_total**2 - 1.0)**2).mean()
        L_shape = torch.stack(L_shape_terms).mean()

        return L_low, L_high, L_overlap, L_parseval, L_shape