| |
|
| | """
|
| | 高级去混响模块 - 基于二进制残差掩码和时域一致性
|
| | 参考: arXiv 2510.00356 - Dereverberation Using Binary Residual Masking
|
| | """
|
| | import numpy as np
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from typing import Tuple, Optional
|
| |
|
| |
|
class BinaryResidualMask(nn.Module):
    """Binary residual mask network - suppresses reverberation instead of
    predicting the full spectrum.

    Core ideas:
    1. Learn to identify and suppress late reflections.
    2. Preserve the direct-path signal.
    3. Phase is learned implicitly via a time-domain consistency loss
       (loss is defined outside this module).

    Reference: arXiv 2510.00356 - Dereverberation Using Binary Residual Masking.

    Architecture: a small U-Net. Bug fixes versus the previous revision:
    - Skip connections are concatenated *between* the transposed-conv upsample
      and the following conv (the old code chained ConvTranspose2d(2C, C) into
      Conv2d(2C, C) inside one Sequential, a guaranteed channel mismatch, and
      concatenated after the Sequential instead).
    - The output head now receives the 32-channel decoder output directly
      (previously a 64-channel concat was fed into Conv2d(32, 1)).
    - Upsampled feature maps are resized to the skip tensor's spatial size,
      so odd sizes (e.g. freq_bins = n_fft // 2 + 1 = 1025) work: MaxPool2d
      floors odd dimensions, so the decoder would otherwise be off by one.
    """

    def __init__(self, n_fft: int = 2048, hop_length: int = 512):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.freq_bins = n_fft // 2 + 1

        # ---- encoder ----------------------------------------------------
        self.encoder1 = self._double_conv(1, 32)
        self.encoder2 = nn.Sequential(nn.MaxPool2d(2), self._double_conv(32, 64))
        self.encoder3 = nn.Sequential(nn.MaxPool2d(2), self._double_conv(64, 128))

        # ---- bottleneck -------------------------------------------------
        self.bottleneck = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )

        # ---- decoder: upsample, then conv over [upsampled ; skip] -------
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = self._single_conv(256, 128)

        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = self._single_conv(128, 64)

        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder1 = self._single_conv(64, 32)

        # ---- mask head: per-bin value in (0, 1) -------------------------
        self.output = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _double_conv(in_ch: int, out_ch: int) -> nn.Sequential:
        """Two 3x3 conv + BN + ReLU layers (encoder building block)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    @staticmethod
    def _single_conv(in_ch: int, out_ch: int) -> nn.Sequential:
        """One 3x3 conv + BN + ReLU (decoder fuse after skip concat)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    @staticmethod
    def _match(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Resize x spatially to match ref before concatenation.

        MaxPool2d floors odd spatial sizes, so the transposed conv can come
        back one pixel short of the corresponding encoder tensor.
        """
        if x.shape[-2:] != ref.shape[-2:]:
            x = F.interpolate(x, size=ref.shape[-2:], mode="nearest")
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, 1, F, T] - input magnitude spectrogram
        Returns:
            mask: [B, 1, F, T] - residual mask in (0, 1)
        """
        # Encoder path (keep activations for skip connections).
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)

        b = self.bottleneck(e3)

        # Decoder path: upsample -> align -> concat skip -> conv.
        d3 = self._match(self.up3(b), e3)
        d3 = self.decoder3(torch.cat([d3, e3], dim=1))

        d2 = self._match(self.up2(d3), e2)
        d2 = self.decoder2(torch.cat([d2, e2], dim=1))

        d1 = self._match(self.up1(d2), e1)
        d1 = self.decoder1(torch.cat([d1, e1], dim=1))

        return self.output(d1)
|
| |
|
| |
|
def advanced_dereverb(
    audio: np.ndarray,
    sr: int = 16000,
    n_fft: int = 2048,
    hop_length: int = 512,
    device: str = "cuda",
) -> Tuple[np.ndarray, np.ndarray]:
    """Heuristic dereverberation - split audio into a dry signal and a reverb tail.

    Works on the STFT magnitude: late reflections are modelled as a decaying
    per-bin envelope of delayed energy, subtracted from the magnitude, and
    the remainder is resynthesised with the original phase.

    Args:
        audio: input audio, shape [samples]
        sr: sample rate (kept for interface compatibility; the algorithm is
            expressed in frames and does not read it directly)
        n_fft: FFT size
        hop_length: hop length
        device: compute device - currently unused (pure NumPy path);
            kept for interface compatibility

    Returns:
        dry_signal: direct-path estimate, float32, same length as input
        reverb_tail: reverb-tail estimate, float32, same length as input
    """
    import librosa

    # STFT -> magnitude + phase (phase is reused for both outputs).
    spec = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=n_fft)
    mag = np.abs(spec).astype(np.float32)
    phase = np.angle(spec)
    n_frames = mag.shape[1]

    # Frame-level loudness; used below to protect strong (vocal) frames.
    rms = librosa.feature.rms(
        y=audio, frame_length=n_fft, hop_length=hop_length, center=True
    )[0]

    # BUG FIX: align the RMS frame count with the STFT frame count *before*
    # the dB conversion. Previously rms_db was computed from the unpadded
    # rms, so vocal_strength could mismatch mag.shape[1] and break the
    # floor_coef broadcast below.
    if len(rms) < n_frames:
        rms = np.pad(rms, (0, n_frames - len(rms)), mode="edge")
    else:
        rms = rms[:n_frames]

    rms_db = 20.0 * np.log10(rms + 1e-8)
    ref_db = float(np.percentile(rms_db, 90))

    # Estimate late reflections: for each bin, either the previous estimate
    # decayed by 0.92 or the energy from two frames ago scaled by 0.80,
    # whichever is larger (a simple exponentially-decaying echo model).
    late_reflections = np.zeros_like(mag, dtype=np.float32)
    for t in range(2, n_frames):
        late_reflections[:, t] = np.maximum(
            late_reflections[:, t - 1] * 0.92,
            mag[:, t - 2] * 0.80,
        )

    # Spectral subtraction of the estimated late energy.
    direct_path = np.maximum(mag - 0.75 * late_reflections, 0.0)

    # Loud frames (within 35 dB of the 90th-percentile level) are considered
    # vocal; strength ramps 0 -> 1 over a 25 dB window.
    vocal_strength = np.clip((rms_db - (ref_db - 35.0)) / 25.0, 0.0, 1.0)

    # Spectral floor: keep a small fraction of the original magnitude where
    # the signal is not dominated by reverb, to avoid musical-noise artifacts.
    reverb_ratio = np.clip(late_reflections / (mag + 1e-8), 0.0, 1.0)
    floor_coef = 0.08 + 0.12 * vocal_strength[np.newaxis, :]
    floor = (1.0 - reverb_ratio) * floor_coef * mag
    direct_path = np.maximum(direct_path, floor)

    # Smooth along time with a small triangular kernel, then re-clip so the
    # estimate never exceeds the observed magnitude.
    kernel = np.array([1, 2, 3, 2, 1], dtype=np.float32)
    kernel /= np.sum(kernel)
    direct_path = np.apply_along_axis(
        lambda row: np.convolve(row, kernel, mode="same"),
        axis=1,
        arr=direct_path,
    )
    direct_path = np.clip(direct_path, 0.0, mag)

    # Whatever magnitude is not direct path is the reverb tail.
    reverb_mag = np.maximum(mag - direct_path, 0.0)

    # Resynthesise both components with the original phase.
    dry_spec = direct_path * np.exp(1j * phase)
    dry_signal = librosa.istft(
        dry_spec, hop_length=hop_length, win_length=n_fft, length=len(audio)
    )

    reverb_spec = reverb_mag * np.exp(1j * phase)
    reverb_tail = librosa.istft(
        reverb_spec, hop_length=hop_length, win_length=n_fft, length=len(audio)
    )

    return dry_signal.astype(np.float32), reverb_tail.astype(np.float32)
|
| |
|
| |
|
def apply_reverb_to_converted(
    converted_dry: np.ndarray,
    original_reverb: np.ndarray,
    mix_ratio: float = 0.8
) -> np.ndarray:
    """Re-apply the original reverb tail to a converted dry vocal.

    Args:
        converted_dry: converted dry signal
        original_reverb: reverb tail extracted from the original audio
        mix_ratio: reverb blend amount (0-1)

    Returns:
        wet_signal: converted result with reverb, float32
    """
    # Trim both signals to the shorter length so they can be summed.
    usable = min(len(converted_dry), len(original_reverb))
    dry_part = converted_dry[:usable]
    tail_part = original_reverb[:usable]

    # Linear mix: full dry signal plus a scaled reverb tail.
    wet_signal = dry_part + mix_ratio * tail_part

    # Soft-limit to tame peaks introduced by the sum.
    from lib.audio import soft_clip
    wet_signal = soft_clip(wet_signal, threshold=0.9, ceiling=0.99)

    return wet_signal.astype(np.float32)
|
| |
|
| |
|
| | if __name__ == "__main__":
|
| |
|
| | print("Testing advanced dereverberation...")
|
| |
|
| |
|
| | sr = 16000
|
| | duration = 2.0
|
| | t = np.linspace(0, duration, int(sr * duration))
|
| |
|
| |
|
| | dry = np.sin(2 * np.pi * 440 * t).astype(np.float32)
|
| |
|
| |
|
| | reverb = np.zeros_like(dry)
|
| | delay_samples = int(0.05 * sr)
|
| | for i in range(3):
|
| | delay = delay_samples * (i + 1)
|
| | decay = 0.5 ** (i + 1)
|
| | if delay < len(reverb):
|
| | reverb[delay:] += dry[:-delay] * decay
|
| |
|
| |
|
| | wet = dry + reverb * 0.5
|
| |
|
| |
|
| | dry_extracted, reverb_extracted = advanced_dereverb(wet, sr)
|
| |
|
| | print(f"Input RMS: {np.sqrt(np.mean(wet**2)):.4f}")
|
| | print(f"Dry RMS: {np.sqrt(np.mean(dry_extracted**2)):.4f}")
|
| | print(f"Reverb RMS: {np.sqrt(np.mean(reverb_extracted**2)):.4f}")
|
| | print(f"Separation ratio: {np.sqrt(np.mean(dry_extracted**2)) / (np.sqrt(np.mean(reverb_extracted**2)) + 1e-8):.2f}")
|
| |
|
| | print("\n[OK] Advanced dereverberation test passed!")
|
| |
|