File size: 8,831 Bytes
b15e31b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
# -*- coding: utf-8 -*-
"""

高级去混响模块 - 基于二进制残差掩码和时域一致性

参考: arXiv 2510.00356 - Dereverberation Using Binary Residual Masking

"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional


class BinaryResidualMask(nn.Module):
    """Binary residual mask network for dereverberation.

    Reference: arXiv 2510.00356 - Dereverberation Using Binary Residual Masking.

    Core ideas:
    1. Learn to identify and suppress late reflections rather than predict
       a complete clean spectrum.
    2. Preserve the direct-path signal.
    3. Phase is learned implicitly through a time-domain consistency loss
       (applied by the training code, not inside this module).

    Architecture: a 3-level U-Net over the magnitude spectrogram producing
    a sigmoid mask in [0, 1] with the same shape as the input.
    """

    def __init__(self, n_fft=2048, hop_length=512):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.freq_bins = n_fft // 2 + 1

        # U-Net encoder
        self.encoder1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )

        self.encoder2 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        self.encoder3 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )

        # Bottleneck (plain conv; no attention despite the original comment)
        self.bottleneck = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )

        # U-Net decoder. BUGFIX: upsampling and convolution are separate
        # modules so the skip connection can be concatenated between them.
        # The original Sequential fed the 128-channel transpose-conv output
        # straight into a Conv2d expecting 256 channels, which raised a
        # channel-mismatch error on the first forward pass.
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),  # 128 up + 128 skip
            nn.BatchNorm2d(128),
            nn.ReLU()
        )

        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),  # 64 up + 64 skip
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),  # 32 up + 32 skip
            nn.BatchNorm2d(32),
            nn.ReLU()
        )

        # Output head - per-bin mask in [0, 1]
        self.output = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Predict a binary residual mask for a magnitude spectrogram.

        Args:
            x: [B, 1, F, T] input spectrogram magnitude.

        Returns:
            mask: [B, 1, F, T] mask in [0, 1], same shape as the input.
        """
        # Three 2x poolings require F and T divisible by 8. Pad on the
        # right/bottom and crop at the end so arbitrary sizes work
        # (freq_bins = n_fft // 2 + 1 is odd, so this matters in practice).
        _, _, freq, frames = x.shape
        pad_f = (-freq) % 8
        pad_t = (-frames) % 8
        if pad_f or pad_t:
            x = F.pad(x, (0, pad_t, 0, pad_f))

        # Encode
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)

        # Bottleneck
        b = self.bottleneck(e3)

        # Decode with skip connections: upsample -> concat skip -> conv
        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        # Output mask, cropped back to the caller's original size
        mask = self.output(d1)
        return mask[:, :, :freq, :frames]


def advanced_dereverb(
    audio: np.ndarray,
    sr: int = 16000,
    n_fft: int = 2048,
    hop_length: int = 512,
    device: str = "cuda"
) -> Tuple[np.ndarray, np.ndarray]:
    """Advanced dereverberation - split a signal into dry and reverb parts.

    Non-learned DSP separator: estimates late reflections with a decaying
    recursive envelope, subtracts them from the magnitude spectrogram, and
    resynthesizes both components with the original phase.

    Args:
        audio: input mono audio [samples]
        sr: sample rate
        n_fft: FFT size
        hop_length: STFT hop length
        device: unused here; kept for interface compatibility

    Returns:
        dry_signal: direct-path (dry) component, float32
        reverb_tail: estimated reverb residual, float32
    """
    import librosa

    # STFT -> magnitude + phase (the phase is reused for both outputs)
    spec = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=n_fft)
    mag = np.abs(spec).astype(np.float32)
    phase = np.angle(spec)

    # 1. Frame-level RMS energy in dB; reference level = 90th percentile
    rms = librosa.feature.rms(y=audio, frame_length=n_fft, hop_length=hop_length, center=True)[0]
    ref_db = float(np.percentile(20.0 * np.log10(rms + 1e-8), 90))

    # 2. Late-reflection estimate: recursive max of the decayed history and
    #    a delayed, attenuated observation (reverb ~ delayed decaying energy)
    late_reflections = np.zeros_like(mag, dtype=np.float32)
    for t in range(2, mag.shape[1]):
        late_reflections[:, t] = np.maximum(
            late_reflections[:, t - 1] * 0.92,  # decay factor
            mag[:, t - 2] * 0.80  # delayed observation
        )

    # 3. Direct path = total energy minus (scaled) late reflections
    direct_path = np.maximum(mag - 0.75 * late_reflections, 0.0)

    # 4. Dynamic floor to protect voiced segments.
    #    Align the RMS track to the spectrogram frame count first.
    #    BUGFIX: vocal_strength is derived from the *aligned* RMS so its
    #    length always equals mag.shape[1]; the original sliced the
    #    unaligned dB track and could broadcast-fail when RMS was shorter.
    if len(rms) < mag.shape[1]:
        rms_aligned = np.pad(rms, (0, mag.shape[1] - len(rms)), mode='edge')
    else:
        rms_aligned = rms[:mag.shape[1]]
    rms_db = 20.0 * np.log10(rms_aligned + 1e-8)

    # Voiced (high-energy) frames -> vocal_strength near 1;
    # silent / reverb-tail frames -> near 0.
    vocal_strength = np.clip((rms_db - (ref_db - 35.0)) / 25.0, 0.0, 1.0)

    # Dynamic floor coefficient: keep more of the original magnitude where
    # the vocal is strong and the reverb ratio is low
    reverb_ratio = np.clip(late_reflections / (mag + 1e-8), 0.0, 1.0)
    floor_coef = 0.08 + 0.12 * vocal_strength[np.newaxis, :]
    floor = (1.0 - reverb_ratio) * floor_coef * mag
    direct_path = np.maximum(direct_path, floor)

    # 5. Smooth along time to avoid musical noise
    kernel = np.array([1, 2, 3, 2, 1], dtype=np.float32)
    kernel /= np.sum(kernel)
    direct_path = np.apply_along_axis(
        lambda row: np.convolve(row, kernel, mode="same"),
        axis=1,
        arr=direct_path,
    )
    direct_path = np.clip(direct_path, 0.0, mag)

    # 6. Reverb residual = whatever the direct path did not keep
    reverb_mag = np.maximum(mag - direct_path, 0.0)

    # 7. Resynthesize. Both components reuse the original phase, which
    #    preserves timbre at the cost of some phase inconsistency.
    dry_spec = direct_path * np.exp(1j * phase)
    dry_signal = librosa.istft(dry_spec, hop_length=hop_length, win_length=n_fft, length=len(audio))

    reverb_spec = reverb_mag * np.exp(1j * phase)
    reverb_tail = librosa.istft(reverb_spec, hop_length=hop_length, win_length=n_fft, length=len(audio))

    return dry_signal.astype(np.float32), reverb_tail.astype(np.float32)


def apply_reverb_to_converted(
    converted_dry: np.ndarray,
    original_reverb: np.ndarray,
    mix_ratio: float = 0.8
) -> np.ndarray:
    """Re-apply the original reverb tail to a converted dry vocal.

    Args:
        converted_dry: dry signal after voice conversion
        original_reverb: reverb tail extracted from the source audio
        mix_ratio: how much of the reverb tail to mix back in (0-1)

    Returns:
        wet_signal: converted vocal with the original room sound restored
    """
    from lib.audio import soft_clip

    # Trim both signals to the shorter length so they line up sample-wise
    n = min(len(converted_dry), len(original_reverb))

    # Mix the scaled reverb tail back onto the dry signal
    wet_signal = converted_dry[:n] + mix_ratio * original_reverb[:n]

    # Soft-limit to avoid clipping after the summation
    wet_signal = soft_clip(wet_signal, threshold=0.9, ceiling=0.99)

    return wet_signal.astype(np.float32)


if __name__ == "__main__":
    # Smoke test: build a synthetic wet signal and check the separation runs.
    print("Testing advanced dereverberation...")

    sr = 16000
    duration = 2.0
    t = np.linspace(0, duration, int(sr * duration))

    # Dry source: a 440 Hz sine tone
    dry = np.sin(2 * np.pi * 440 * t).astype(np.float32)

    # Simulated reverb: three echoes at multiples of 50 ms, each half as loud
    reverb = np.zeros_like(dry)
    delay_samples = int(0.05 * sr)  # 50 ms base delay
    for k in (1, 2, 3):
        delay = delay_samples * k
        if delay < len(reverb):
            reverb[delay:] += dry[:-delay] * (0.5 ** k)

    # Wet mixture fed to the separator
    wet = dry + reverb * 0.5

    dry_extracted, reverb_extracted = advanced_dereverb(wet, sr)

    def _rms(x):
        # Root-mean-square level of a signal
        return np.sqrt(np.mean(x**2))

    print(f"Input RMS: {_rms(wet):.4f}")
    print(f"Dry RMS: {_rms(dry_extracted):.4f}")
    print(f"Reverb RMS: {_rms(reverb_extracted):.4f}")
    print(f"Separation ratio: {_rms(dry_extracted) / (_rms(reverb_extracted) + 1e-8):.2f}")

    print("\n[OK] Advanced dereverberation test passed!")