clearwave-ai / denoiser.py
sriiram18's picture
Update denoiser.py
81a9f2f verified
raw
history blame
8.24 kB
"""
Department 1 - Denoiser
Uses DeepFilterNet3 (deep learning) for noise removal.
✅ UPGRADED from noisereduce → DeepFilterNet3:
- AI-based speech enhancement (not just signal processing)
- 48000 Hz full-band audio support
- Stereo preserved if original is stereo
- Loudness normalisation target (-18 dB)
- PCM_24 high quality output
- Fallback to noisereduce if DeepFilterNet3 unavailable
"""
import os
import time
import subprocess
import numpy as np
import soundfile as sf
import logging
# Module-level logger; Denoiser reports timings and fallback warnings through it.
logger = logging.getLogger(__name__)
# Sample rate (Hz) that input audio is converted/resampled to before denoising.
# 48000 Hz = DeepFilterNet3 native sample rate (full-band).
TARGET_SR = 48000
# Loudness target used by Denoiser._normalise: handed to pyloudnorm as the
# target loudness, or interpreted as a dB RMS level in the fallback path.
TARGET_LOUDNESS = -18.0
class Denoiser:
    """Speech denoiser using DeepFilterNet3, with noisereduce as a fallback.

    Pipeline implemented by :meth:`process`: convert the input to WAV, read
    it, denoise it (DeepFilterNet3 when it loaded successfully, otherwise
    noisereduce), normalise loudness and write a 24-bit PCM WAV.

    Attributes:
        df_model: DeepFilterNet3 model returned by ``init_df()``, or ``None``
            when DeepFilterNet3 is unavailable (noisereduce is used instead).
        df_state: DeepFilterNet3 state object returned by ``init_df()``, or
            ``None`` when DeepFilterNet3 is unavailable.
    """

    def __init__(self):
        """Create the denoiser and try to load DeepFilterNet3."""
        self.df_model = None
        self.df_state = None
        self._load_deepfilter()

    def _load_deepfilter(self):
        """Try to load DeepFilterNet3. Falls back to noisereduce if unavailable.

        Both a missing package and a failing ``init_df()`` (for example a
        model download error) leave ``df_model``/``df_state`` as ``None`` so
        the instance stays usable with the noisereduce fallback.
        """
        self.df_model = None
        self.df_state = None
        try:
            from df import init_df
        except ImportError:
            print("[Denoiser] ⚠️ DeepFilterNet3 not installed.")
            # The ``df`` module is shipped by the ``deepfilternet`` package.
            print("[Denoiser] Run: pip install deepfilternet")
            print("[Denoiser] ↩️ Falling back to noisereduce")
            return

        try:
            model, state, _ = init_df()
        except Exception as e:
            # Best effort: never let model initialisation break construction.
            logger.warning("[Denoiser] DeepFilterNet3 failed to initialise: %s", e)
            print("[Denoiser] ↩️ Falling back to noisereduce")
            return

        self.df_model, self.df_state = model, state
        print("[Denoiser] ✅ DeepFilterNet3 loaded — AI-powered denoising active")

    def process(self, audio_path: str, out_dir: str) -> str:
        """Denoise ``audio_path`` and write the result into ``out_dir``.

        Args:
            audio_path: Path of the input audio file (any format ffmpeg or
                soundfile can read).
            out_dir: Directory receiving ``input.wav`` (the converted input)
                and ``denoised.wav``; created if it does not exist.

        Returns:
            Path of the written ``denoised.wav``.

        Raises:
            RuntimeError: If the input cannot be converted or read.
        """
        t0 = time.time()
        os.makedirs(out_dir, exist_ok=True)

        # Step 1: Convert to high quality WAV
        wav_path = os.path.join(out_dir, "input.wav")
        self._convert_to_wav(audio_path, wav_path)

        # Step 2: Read audio
        audio, sr = sf.read(wav_path, always_2d=True)
        original_channels = audio.shape[1]

        # Step 3: Denoise — DeepFilterNet3 or fallback
        use_deepfilter = self.df_model is not None
        if use_deepfilter and sr != TARGET_SR:
            # The soundfile conversion fallback keeps the source rate, so
            # resample here and track the new rate; otherwise the output
            # would be written with a stale sample rate.
            resampled = self._try_resample(audio, sr, TARGET_SR)
            if resampled is None:
                logger.warning(
                    "[Denoiser] Cannot resample %s Hz input to %s Hz, "
                    "using noisereduce at the original rate",
                    sr,
                    TARGET_SR,
                )
                use_deepfilter = False
            else:
                audio, sr = resampled, TARGET_SR

        if use_deepfilter:
            audio = self._denoise_deepfilter(audio, sr, original_channels)
        else:
            audio = self._denoise_noisereduce(audio, sr, original_channels)

        # Step 4: Normalise loudness
        audio = self._normalise(audio, sr)

        # Step 5: Save at high quality (PCM_24)
        out_path = os.path.join(out_dir, "denoised.wav")
        sf.write(out_path, audio, sr, subtype="PCM_24")

        elapsed = time.time() - t0
        logger.info(
            "[Denoiser] Done in %.2fs — %sHz, %sch", elapsed, sr, original_channels
        )
        print(f"[Denoiser] ✅ Done in {elapsed:.2f}s")
        return out_path

    # =========================================================
    # ✅ PRIMARY: DeepFilterNet3 (AI-based, best quality)
    # =========================================================
    def _denoise_deepfilter(self, audio: np.ndarray, sr: int, channels: int) -> np.ndarray:
        """Denoise using DeepFilterNet3, one channel at a time.

        Args:
            audio: Samples shaped ``(samples, channels)`` (or ``(samples,)``
                for mono).
            sr: Sample rate of ``audio``. If it differs from ``TARGET_SR``
                the audio is resampled first and the result is at
                ``TARGET_SR``; ``process`` resamples beforehand so that it
                can track the rate.
            channels: Number of channels in ``audio``.

        Returns:
            Denoised float32 audio: ``(samples, channels)`` for multi-channel
            input, ``(samples,)`` for mono. On any failure the result of
            ``_denoise_noisereduce`` is returned instead.
        """
        try:
            from df import enhance
            import torch

            if sr != TARGET_SR:
                resampled = self._try_resample(audio, sr, TARGET_SR)
                if resampled is None:
                    # Running the 48 kHz model on another rate would give
                    # wrong results; use the fallback at the original rate.
                    raise RuntimeError(
                        f"could not resample {sr} Hz audio to {TARGET_SR} Hz"
                    )
                audio, sr = resampled, TARGET_SR

            def enhance_channel(samples: np.ndarray) -> np.ndarray:
                # DeepFilterNet expects a (1, samples) float32 tensor.
                tensor = torch.from_numpy(samples.astype(np.float32)).unsqueeze(0)
                enhanced = enhance(self.df_model, self.df_state, tensor)
                return enhanced.detach().cpu().numpy().reshape(-1)

            denoised = self._map_channels(audio, channels, enhance_channel)
            print("[Denoiser] 🤖 DeepFilterNet3 enhancement complete")
            return denoised
        except Exception as e:
            logger.warning(
                "[Denoiser] DeepFilterNet3 failed: %s, falling back to noisereduce", e
            )
            return self._denoise_noisereduce(audio, sr, channels)

    # =========================================================
    # ↩️ FALLBACK: noisereduce (signal processing)
    # =========================================================
    def _denoise_noisereduce(self, audio: np.ndarray, sr: int, channels: int) -> np.ndarray:
        """Fallback denoiser using the noisereduce library.

        Args:
            audio: Samples shaped ``(samples, channels)`` (or ``(samples,)``
                for mono).
            sr: Sample rate of ``audio``.
            channels: Number of channels in ``audio``.

        Returns:
            Denoised float32 audio, or the unmodified input if noisereduce is
            missing or fails.
        """
        try:
            import noisereduce as nr

            print("[Denoiser] ↩️ Using noisereduce fallback")

            def reduce_channel(samples: np.ndarray) -> np.ndarray:
                return nr.reduce_noise(
                    y=samples.astype(np.float32),
                    sr=sr,
                    stationary=True,
                    prop_decrease=0.75,  # less aggressive to preserve voice
                ).astype(np.float32)

            return self._map_channels(audio, channels, reduce_channel)
        except Exception as e:
            logger.warning(
                "[Denoiser] noisereduce also failed: %s, returning raw audio", e
            )
            return audio

    # =========================================================
    # 🔧 HELPERS
    # =========================================================
    @staticmethod
    def _map_channels(audio: np.ndarray, channels: int, fn) -> np.ndarray:
        """Apply ``fn`` to each channel of ``audio`` and recombine the result.

        Multi-channel audio comes back as ``(samples, channels)``; mono audio
        comes back as a 1-D array.
        """
        if channels > 1:
            return np.stack([fn(audio[:, ch]) for ch in range(channels)], axis=1)
        # Select the column explicitly: squeeze() would collapse a one-sample
        # signal to a 0-d array.
        mono = audio[:, 0] if audio.ndim > 1 else audio
        return fn(mono)

    def _try_resample(self, audio: np.ndarray, orig_sr: int, target_sr: int):
        """Resample along axis 0; return float32 audio, or ``None`` on failure."""
        if orig_sr == target_sr:
            return audio.astype(np.float32)
        try:
            from math import gcd
            from scipy.signal import resample_poly

            g = gcd(orig_sr, target_sr)
            up, down = target_sr // g, orig_sr // g
            # axis=0 resamples every channel of (samples, channels) at once
            # and works unchanged for 1-D mono input.
            return resample_poly(audio, up, down, axis=0).astype(np.float32)
        except Exception as e:
            logger.warning("[Denoiser] Resample failed: %s", e)
            return None

    def _resample(self, audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        """Resample audio to target sample rate using scipy.

        Args:
            audio: Samples shaped ``(samples,)`` or ``(samples, channels)``.
            orig_sr: Current sample rate of ``audio``.
            target_sr: Desired sample rate.

        Returns:
            Resampled float32 audio. If resampling fails, ``audio`` itself is
            returned unchanged (still at ``orig_sr``).
        """
        resampled = self._try_resample(audio, orig_sr, target_sr)
        return audio if resampled is None else resampled

    def _convert_to_wav(self, src: str, dst: str):
        """Convert any audio format to high quality WAV at 48kHz.

        ffmpeg is tried first. If it is not installed or fails, the file is
        copied through soundfile instead, which keeps the source sample rate.

        Args:
            src: Path of the input audio file.
            dst: Path of the 24-bit PCM WAV file to write.

        Raises:
            RuntimeError: If neither ffmpeg nor soundfile can read ``src``.
        """
        cmd = [
            "ffmpeg", "-y", "-i", src,
            "-acodec", "pcm_s24le",
            "-ar", str(TARGET_SR),  # 48kHz for DeepFilterNet3
            dst,
        ]
        try:
            # errors="replace": ffmpeg may echo non-UTF-8 bytes (file names,
            # metadata) that would otherwise raise UnicodeDecodeError.
            result = subprocess.run(
                cmd, capture_output=True, text=True, errors="replace"
            )
        except OSError as e:
            # ffmpeg binary missing or not executable.
            ffmpeg_error = str(e)
        else:
            if result.returncode == 0:
                return
            stderr_lines = result.stderr.strip().splitlines()
            ffmpeg_error = (
                stderr_lines[-1] if stderr_lines else f"exit code {result.returncode}"
            )

        logger.warning(
            "[Denoiser] ffmpeg conversion failed (%s), trying soundfile", ffmpeg_error
        )
        try:
            data, sr = sf.read(src, always_2d=True)
            sf.write(dst, data, sr, subtype="PCM_24")
        except Exception as e:
            raise RuntimeError(f"Cannot read audio file: {e}") from e

    def _normalise(self, audio: np.ndarray, sr: int) -> np.ndarray:
        """Normalise to target loudness.

        Uses pyloudnorm's integrated loudness when available; if pyloudnorm
        is missing or fails, scales the signal to an RMS level of
        ``TARGET_LOUDNESS`` dB instead.

        Args:
            audio: Samples shaped ``(samples,)`` or ``(samples, channels)``.
            sr: Sample rate of ``audio``.

        Returns:
            float32 audio clipped to ``[-1.0, 1.0]``.
        """
        if audio.size == 0:
            # Nothing to measure; avoids a NaN from the mean of an empty array.
            return audio.astype(np.float32)
        try:
            import pyloudnorm as pyln

            meter = pyln.Meter(sr)
            loudness = meter.integrated_loudness(audio)
            # Silence measures as -inf; leave such input unscaled.
            if np.isfinite(loudness) and loudness < 0:
                audio = pyln.normalize.loudness(audio, loudness, TARGET_LOUDNESS)
                print(f"[Denoiser] Loudness: {loudness:.1f}dB → {TARGET_LOUDNESS}dB")
        except Exception as e:
            logger.debug(
                "[Denoiser] pyloudnorm unavailable or failed (%s), using RMS", e
            )
            rms = np.sqrt(np.mean(audio ** 2))
            if rms > 1e-9:
                target = 10 ** (TARGET_LOUDNESS / 20.0)
                audio = audio * (target / rms)
                print(f"[Denoiser] RMS normalised to {TARGET_LOUDNESS}dB")
        return np.clip(audio, -1.0, 1.0).astype(np.float32)