# -*- coding: utf-8 -*-
"""
高级去混响模块 - 基于二进制残差掩码和时域一致性
参考: arXiv 2510.00356 - Dereverberation Using Binary Residual Masking
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
class BinaryResidualMask(nn.Module):
    """
    Binary residual mask network — suppresses reverberation rather than
    predicting the full spectrum.

    Core ideas:
    1. Learn to identify and suppress late reflections.
    2. Preserve the direct-path signal.
    3. Phase is learned implicitly via a time-domain consistency loss
       (applied by the training code, not defined here).

    Architecture: a small 2-D U-Net over magnitude spectrograms.
    Input/output shape: [B, 1, F, T] with F = n_fft // 2 + 1.

    NOTE(fix): the previous version chained ``ConvTranspose2d(C, C/2)``
    directly into ``Conv2d(C, C/2)`` inside one ``Sequential`` and applied
    the skip-concat *after* each decoder stage, which made every decoder
    (and the output head) receive the wrong channel count at runtime.
    Upsampling and post-concat convolution are now separate modules, with
    the skip connection concatenated in between. Upsampled feature maps
    are also resized to the skip's spatial size, which is required when
    F is odd (e.g. n_fft=2048 -> F=1025 is not divisible by 8).
    """

    def __init__(self, n_fft=2048, hop_length=512):
        """
        Args:
            n_fft: FFT size used to produce the input spectrogram.
            hop_length: STFT hop length (kept for reference by callers).
        """
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.freq_bins = n_fft // 2 + 1

        # --- U-Net encoder ---
        self.encoder1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.encoder2 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.encoder3 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )

        # --- Bottleneck ---
        self.bottleneck = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )

        # --- U-Net decoder: upsample -> concat skip -> conv ---
        # Each ``upN`` halves channels and doubles the spatial size; the
        # matching ``decoderN`` then fuses the concatenated skip features.
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),  # 128 (up) + 128 (e3)
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),   # 64 (up) + 64 (e2)
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder1 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),    # 32 (up) + 32 (e1)
            nn.BatchNorm2d(32),
            nn.ReLU()
        )

        # --- Output head: per-bin soft mask in [0, 1] ---
        self.output = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid()
        )

    @staticmethod
    def _match(upsampled, skip):
        """Resize ``upsampled`` to the skip connection's spatial size.

        Needed when the input F/T dimensions are not divisible by 8 (the
        three MaxPool2d levels floor-divide, so upsampling may come back
        one pixel short, e.g. F=1025 -> ... -> 1024).
        """
        if upsampled.shape[2:] != skip.shape[2:]:
            upsampled = F.interpolate(upsampled, size=skip.shape[2:], mode="nearest")
        return upsampled

    def forward(self, x):
        """
        Args:
            x: [B, 1, F, T] — input magnitude spectrogram.
        Returns:
            mask: [B, 1, F, T] — residual mask, values in (0, 1).
        """
        # Encode
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        # Bottleneck
        b = self.bottleneck(e3)
        # Decode with skip connections (concat BEFORE each decoder conv)
        d3 = self.decoder3(torch.cat([self._match(self.up3(b), e3), e3], dim=1))
        d2 = self.decoder2(torch.cat([self._match(self.up2(d3), e2), e2], dim=1))
        d1 = self.decoder1(torch.cat([self._match(self.up1(d2), e1), e1], dim=1))
        # Output mask
        return self.output(d1)
def advanced_dereverb(
    audio: np.ndarray,
    sr: int = 16000,
    n_fft: int = 2048,
    hop_length: int = 512,
    device: str = "cuda"
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Advanced dereverberation — separate the dry (direct-path) signal from
    the reverb tail using a DSP heuristic (no neural network involved).

    Args:
        audio: input audio, shape [samples]
        sr: sample rate
        n_fft: FFT size
        hop_length: STFT hop length
        device: compute device (currently unused; kept for API stability)
    Returns:
        dry_signal: dry (direct-path) component
        reverb_tail: reverb component
    """
    import librosa

    # STFT
    spec = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=n_fft)
    mag = np.abs(spec).astype(np.float32)
    phase = np.angle(spec)
    n_frames = mag.shape[1]

    # Energy-based reverb detection
    # 1. Frame-wise RMS energy, extended/truncated to the STFT frame count
    #    FIRST so that all per-frame arrays below are aligned with `mag`.
    #    (Previously `vocal_strength` was derived from the un-extended RMS,
    #    which broke broadcasting whenever the frame counts differed.)
    rms = librosa.feature.rms(y=audio, frame_length=n_fft, hop_length=hop_length, center=True)[0]
    if len(rms) < n_frames:
        rms_extended = np.pad(rms, (0, n_frames - len(rms)), mode='edge')
    else:
        rms_extended = rms[:n_frames]
    rms_db = 20.0 * np.log10(rms_extended + 1e-8)
    ref_db = float(np.percentile(rms_db, 90))

    # 2. Detect late reflections.
    #    Late-reflection signature: energy decay + time delay.
    late_reflections = np.zeros_like(mag, dtype=np.float32)
    for t in range(2, n_frames):
        # Recursive estimate: decayed history vs. delayed observation
        late_reflections[:, t] = np.maximum(
            late_reflections[:, t - 1] * 0.92,  # decay coefficient
            mag[:, t - 2] * 0.80                # delayed observation
        )

    # 3. Direct path = total energy minus (scaled) late reflections.
    direct_path = np.maximum(mag - 0.75 * late_reflections, 0.0)

    # 4. Dynamic floor: protect voiced segments.
    #    Voiced (high-energy) frames -> vocal_strength near 1;
    #    silent/reverb-tail frames -> near 0.
    vocal_strength = np.clip((rms_db - (ref_db - 35.0)) / 25.0, 0.0, 1.0)
    reverb_ratio = np.clip(late_reflections / (mag + 1e-8), 0.0, 1.0)
    floor_coef = 0.08 + 0.12 * vocal_strength[np.newaxis, :]
    floor = (1.0 - reverb_ratio) * floor_coef * mag
    direct_path = np.maximum(direct_path, floor)

    # 5. Temporal smoothing along frames (avoids musical noise).
    kernel = np.array([1, 2, 3, 2, 1], dtype=np.float32)
    kernel /= np.sum(kernel)
    direct_path = np.apply_along_axis(
        lambda row: np.convolve(row, kernel, mode="same"),
        axis=1,
        arr=direct_path,
    )
    direct_path = np.clip(direct_path, 0.0, mag)

    # 6. Reverb residual is whatever magnitude the direct path did not keep.
    reverb_mag = mag - direct_path
    reverb_mag = np.maximum(reverb_mag, 0.0)

    # 7. Resynthesize both components with the ORIGINAL phase
    #    (preserves timbre; phase is not separated by this method).
    dry_spec = direct_path * np.exp(1j * phase)
    dry_signal = librosa.istft(dry_spec, hop_length=hop_length, win_length=n_fft, length=len(audio))
    reverb_spec = reverb_mag * np.exp(1j * phase)
    reverb_tail = librosa.istft(reverb_spec, hop_length=hop_length, win_length=n_fft, length=len(audio))
    return dry_signal.astype(np.float32), reverb_tail.astype(np.float32)
def apply_reverb_to_converted(
    converted_dry: np.ndarray,
    original_reverb: np.ndarray,
    mix_ratio: float = 0.8
) -> np.ndarray:
    """
    Re-apply the original reverb tail onto the converted dry vocal.

    Args:
        converted_dry: dry signal after voice conversion
        original_reverb: reverb tail extracted from the original audio
        mix_ratio: reverb mix level in [0, 1]
    Returns:
        wet_signal: converted result with reverb restored
    """
    from lib.audio import soft_clip

    # Truncate both signals to the shorter length so they line up sample-wise.
    length = min(len(converted_dry), len(original_reverb))
    dry_part = converted_dry[:length]
    tail_part = original_reverb[:length]

    # Linear mix: dry signal plus scaled reverb tail.
    wet_signal = dry_part + mix_ratio * tail_part

    # Soft limiting guards against clipping introduced by the summation.
    wet_signal = soft_clip(wet_signal, threshold=0.9, ceiling=0.99)
    return wet_signal.astype(np.float32)
if __name__ == "__main__":
    # Smoke test: synthesize a reverberant tone, then try to separate it.
    print("Testing advanced dereverberation...")

    sr = 16000
    duration = 2.0
    t = np.linspace(0, duration, int(sr * duration))

    # Dry source: a 440 Hz sine tone.
    dry = np.sin(2 * np.pi * 440 * t).astype(np.float32)

    # Synthetic reverb: three decaying echoes at multiples of a 50 ms delay.
    reverb = np.zeros_like(dry)
    delay_samples = int(0.05 * sr)
    for tap in range(1, 4):
        delay = delay_samples * tap
        decay = 0.5 ** tap
        if delay < len(reverb):
            reverb[delay:] += dry[:-delay] * decay

    # Wet mixture = dry + attenuated reverb.
    wet = dry + reverb * 0.5

    # Run the separation.
    dry_extracted, reverb_extracted = advanced_dereverb(wet, sr)

    def _rms(x):
        return np.sqrt(np.mean(x ** 2))

    print(f"Input RMS: {_rms(wet):.4f}")
    print(f"Dry RMS: {_rms(dry_extracted):.4f}")
    print(f"Reverb RMS: {_rms(reverb_extracted):.4f}")
    print(f"Separation ratio: {_rms(dry_extracted) / (_rms(reverb_extracted) + 1e-8):.2f}")
    print("\n[OK] Advanced dereverberation test passed!")
|