AI-RVC / infer /advanced_dereverb.py
mason369's picture
Upload folder using huggingface_hub
b15e31b verified
# -*- coding: utf-8 -*-
"""
高级去混响模块 - 基于二进制残差掩码和时域一致性
参考: arXiv 2510.00356 - Dereverberation Using Binary Residual Masking
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
class BinaryResidualMask(nn.Module):
"""
二进制残差掩码网络 - 专注于抑制混响而非预测完整频谱
核心思想:
1. 学习识别并抑制晚期反射(late reflections)
2. 保留直达声路径(direct path)
3. 使用时域一致性损失隐式学习相位
"""
def __init__(self, n_fft=2048, hop_length=512):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.freq_bins = n_fft // 2 + 1
# U-Net编码器
self.encoder1 = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.Conv2d(32, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU()
)
self.encoder2 = nn.Sequential(
nn.MaxPool2d(2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU()
)
self.encoder3 = nn.Sequential(
nn.MaxPool2d(2),
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Conv2d(128, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU()
)
# 瓶颈层 - 时序注意力
self.bottleneck = nn.Sequential(
nn.MaxPool2d(2),
nn.Conv2d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU()
)
# U-Net解码器
self.decoder3 = nn.Sequential(
nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
nn.Conv2d(256, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU()
)
self.decoder2 = nn.Sequential(
nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU()
)
self.decoder1 = nn.Sequential(
nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
nn.Conv2d(64, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU()
)
# 输出层 - 二进制掩码
self.output = nn.Sequential(
nn.Conv2d(32, 1, kernel_size=1),
nn.Sigmoid() # 输出0-1的掩码
)
def forward(self, x):
"""
Args:
x: [B, 1, F, T] - 输入频谱幅度
Returns:
mask: [B, 1, F, T] - 二进制残差掩码
"""
# 编码
e1 = self.encoder1(x)
e2 = self.encoder2(e1)
e3 = self.encoder3(e2)
# 瓶颈
b = self.bottleneck(e3)
# 解码 + 跳跃连接
d3 = self.decoder3(b)
d3 = torch.cat([d3, e3], dim=1)
d2 = self.decoder2(d3)
d2 = torch.cat([d2, e2], dim=1)
d1 = self.decoder1(d2)
d1 = torch.cat([d1, e1], dim=1)
# 输出掩码
mask = self.output(d1)
return mask
def advanced_dereverb(
audio: np.ndarray,
sr: int = 16000,
n_fft: int = 2048,
hop_length: int = 512,
device: str = "cuda"
) -> Tuple[np.ndarray, np.ndarray]:
"""
高级去混响 - 分离干声和混响
Args:
audio: 输入音频 [samples]
sr: 采样率
n_fft: FFT大小
hop_length: 跳跃长度
device: 计算设备
Returns:
dry_signal: 干声(直达声)
reverb_tail: 混响尾巴
"""
import librosa
# STFT
spec = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=n_fft)
mag = np.abs(spec).astype(np.float32)
phase = np.angle(spec)
# 基于能量的混响检测
# 1. 计算时域RMS能量
rms = librosa.feature.rms(y=audio, frame_length=n_fft, hop_length=hop_length, center=True)[0]
rms_db = 20.0 * np.log10(rms + 1e-8)
ref_db = float(np.percentile(rms_db, 90))
# 2. 检测晚期反射(late reflections)
# 晚期反射特征:能量衰减 + 时间延迟
late_reflections = np.zeros_like(mag, dtype=np.float32)
for t in range(2, mag.shape[1]):
# 递归估计:衰减的历史 + 延迟的观测
late_reflections[:, t] = np.maximum(
late_reflections[:, t - 1] * 0.92, # 衰减系数
mag[:, t - 2] * 0.80 # 延迟观测
)
# 3. 计算直达声(direct path)
# 直达声 = 总能量 - 晚期反射
direct_path = np.maximum(mag - 0.75 * late_reflections, 0.0)
# 4. 动态floor:保护有声段
# 扩展RMS到频谱帧数
if len(rms) < mag.shape[1]:
rms_extended = np.pad(rms, (0, mag.shape[1] - len(rms)), mode='edge')
else:
rms_extended = rms[:mag.shape[1]]
# 有声段(高能量):vocal_strength接近1
# 无声段(低能量/混响尾):vocal_strength接近0
vocal_strength = np.clip((rms_db[:len(rms_extended)] - (ref_db - 35.0)) / 25.0, 0.0, 1.0)
# 动态floor系数
reverb_ratio = np.clip(late_reflections / (mag + 1e-8), 0.0, 1.0)
floor_coef = 0.08 + 0.12 * vocal_strength[np.newaxis, :]
floor = (1.0 - reverb_ratio) * floor_coef * mag
direct_path = np.maximum(direct_path, floor)
# 5. 时域平滑(避免音乐噪声)
kernel = np.array([1, 2, 3, 2, 1], dtype=np.float32)
kernel /= np.sum(kernel)
direct_path = np.apply_along_axis(
lambda row: np.convolve(row, kernel, mode="same"),
axis=1,
arr=direct_path,
)
direct_path = np.clip(direct_path, 0.0, mag)
# 6. 计算混响残差
reverb_mag = mag - direct_path
reverb_mag = np.maximum(reverb_mag, 0.0)
# 7. 重建音频
# 干声:使用原始相位(保留音色)
dry_spec = direct_path * np.exp(1j * phase)
dry_signal = librosa.istft(dry_spec, hop_length=hop_length, win_length=n_fft, length=len(audio))
# 混响:使用原始相位
reverb_spec = reverb_mag * np.exp(1j * phase)
reverb_tail = librosa.istft(reverb_spec, hop_length=hop_length, win_length=n_fft, length=len(audio))
return dry_signal.astype(np.float32), reverb_tail.astype(np.float32)
def apply_reverb_to_converted(
converted_dry: np.ndarray,
original_reverb: np.ndarray,
mix_ratio: float = 0.8
) -> np.ndarray:
"""
将原始混响重新应用到转换后的干声上
Args:
converted_dry: 转换后的干声
original_reverb: 原始混响尾巴
mix_ratio: 混响混合比例 (0-1)
Returns:
wet_signal: 带混响的转换结果
"""
# 对齐长度
min_len = min(len(converted_dry), len(original_reverb))
converted_dry = converted_dry[:min_len]
original_reverb = original_reverb[:min_len]
# 混合
wet_signal = converted_dry + mix_ratio * original_reverb
# 软限幅
from lib.audio import soft_clip
wet_signal = soft_clip(wet_signal, threshold=0.9, ceiling=0.99)
return wet_signal.astype(np.float32)
if __name__ == "__main__":
# 测试
print("Testing advanced dereverberation...")
# 生成测试信号:干声 + 混响
sr = 16000
duration = 2.0
t = np.linspace(0, duration, int(sr * duration))
# 干声:440Hz正弦波
dry = np.sin(2 * np.pi * 440 * t).astype(np.float32)
# 混响:衰减的延迟
reverb = np.zeros_like(dry)
delay_samples = int(0.05 * sr) # 50ms延迟
for i in range(3):
delay = delay_samples * (i + 1)
decay = 0.5 ** (i + 1)
if delay < len(reverb):
reverb[delay:] += dry[:-delay] * decay
# 混合信号
wet = dry + reverb * 0.5
# 去混响
dry_extracted, reverb_extracted = advanced_dereverb(wet, sr)
print(f"Input RMS: {np.sqrt(np.mean(wet**2)):.4f}")
print(f"Dry RMS: {np.sqrt(np.mean(dry_extracted**2)):.4f}")
print(f"Reverb RMS: {np.sqrt(np.mean(reverb_extracted**2)):.4f}")
print(f"Separation ratio: {np.sqrt(np.mean(dry_extracted**2)) / (np.sqrt(np.mean(reverb_extracted**2)) + 1e-8):.2f}")
print("\n[OK] Advanced dereverberation test passed!")