# audio-denoiser / app.py
# (Hugging Face Space page metadata — author 91prince, commit "Update app.py",
#  revision b9b7efb verified — converted to comments so the file parses.)
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import gradio as gr
import asyncio
from asyncio import base_events
# ==========================================
# PATCH ASYNCIO EVENT LOOP DESTRUCTOR
# (Silence "Invalid file descriptor: -1" on shutdown)
# ==========================================
def _wrap_loop_del(original):
    """Build a replacement ``__del__`` that delegates to *original* but
    swallows any exception raised during interpreter-shutdown cleanup,
    e.g. ``ValueError: Invalid file descriptor: -1``."""
    def _quiet_del(self):
        try:
            original(self)
        except Exception:
            # Nothing useful can be done this late in teardown; stay silent.
            pass
    return _quiet_del


base_events.BaseEventLoop.__del__ = _wrap_loop_del(base_events.BaseEventLoop.__del__)
# ==========================================
# CONFIG
# ==========================================
# Runtime configuration shared by the STFT/mel front-end, the optional
# learned U-Net mask, and the classical spectral gate.
CONFIG = {
    "sample_rate": 16000,  # working sample rate; inputs are resampled to this
    "n_fft": 1024,  # STFT window size
    "hop_length": 256,  # STFT hop (frame step) in samples
    "n_mels": 80,  # mel bins fed to the U-Net
    "model_path": "best_denoiser_model.pth",  # optional learned-mask checkpoint
    # Classical noise-reduction parameters
    "noise_reduction_strength": 3.0,  # 1.0 = mild, 2–3 = stronger
    "noise_floor_mult": 0.2,  # how much noise to leave (0.1 ≈ -20 dB)
}

# Run on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[app] Using device: {DEVICE}")
# ==========================================
# MODEL (must match training)
# ==========================================
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an additive skip connection.

    When the channel count changes, the skip path uses a 1x1 conv + BN so
    both branches have matching shapes before the addition.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Identity skip by default; projected skip when channels differ.
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        skip = self.shortcut(x)
        h = self.conv1(x)
        h = F.relu(self.bn1(h))
        h = self.bn2(self.conv2(h))
        # Add the skip branch, then a final non-linearity.
        return F.relu(h + skip)
class AdvancedResUNet(nn.Module):
    """Residual U-Net that predicts a soft mask over a log-mel spectrogram.

    Input:  (B, 1, n_mels, frames) log-mel features.
    Output: (B, 1, n_mels, frames) mask with values in (0, 1).
    The caller multiplies this mask with the noisy log-mel input; the network
    itself returns ONLY the mask.
    """

    def __init__(self):
        super().__init__()
        # Encoder: three residual stages, each followed by 2x2 max-pooling.
        self.enc1 = ResidualBlock(1, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ResidualBlock(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ResidualBlock(64, 128)
        self.pool3 = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = ResidualBlock(128, 256)
        # Decoder: transposed-conv upsampling + residual refinement, with
        # channel-concatenated skip connections from the encoder.
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = ResidualBlock(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = ResidualBlock(128, 64)
        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1 = ResidualBlock(64, 32)
        self.final_conv = nn.Conv2d(32, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def _up_and_merge(self, up, dec, feat, skip):
        """One decoder stage: upsample, size-align with the encoder skip
        tensor (odd spatial sizes can make them differ), concat, refine."""
        feat = up(feat)
        if feat.shape != skip.shape:
            feat = F.interpolate(feat, size=skip.shape[2:])
        return dec(torch.cat([feat, skip], dim=1))

    def forward(self, x):
        # Encoder path, keeping activations for the skip connections.
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        b = self.bottleneck(self.pool3(e3))
        # Decoder path.
        d = self._up_and_merge(self.up3, self.dec3, b, e3)
        d = self._up_and_merge(self.up2, self.dec2, d, e2)
        d = self._up_and_merge(self.up1, self.dec1, d, e1)
        # IMPORTANT: return the MASK itself, not input * mask.
        return self.sigmoid(self.final_conv(d))
# ==========================================
# GLOBALS: model + transforms
# ==========================================
# Lazily populated by load_model_and_transforms(); all None until first use.
_model = None  # AdvancedResUNet instance, or None when no checkpoint exists
_mel_scale = None  # torchaudio MelScale (linear magnitude -> mel)
_inv_mel = None  # torchaudio InverseMelScale (mel -> linear magnitude)
_window = None  # Hann window tensor shared by STFT/ISTFT
def load_model_and_transforms():
    """Lazily build the denoiser model and STFT/mel transforms (exactly once).

    Returns:
        Tuple ``(_model, _mel_scale, _inv_mel, _window)``. ``_model`` is the
        trained ``AdvancedResUNet`` in eval mode, or ``None`` when no
        checkpoint file exists (then only classical spectral gating is used).
    """
    global _model, _mel_scale, _inv_mel, _window
    # BUGFIX: key the one-time init on _mel_scale rather than _model. When the
    # checkpoint is missing, _model legitimately stays None, so the old guard
    # ("if _model is None") re-ran this whole setup — re-printing warnings and
    # rebuilding the transforms — on every single call.
    if _mel_scale is None:
        print(f"[app] Loading model from {CONFIG['model_path']} on {DEVICE}...")
        if not os.path.exists(CONFIG["model_path"]):
            print(
                f"[app] WARNING: Model checkpoint not found at {CONFIG['model_path']}."
            )
            print("[app] Denoiser will use ONLY classical spectral gating.")
            _model = None
        else:
            # Only construct the network when there is actually a checkpoint
            # to load into it (the old code built-and-discarded it otherwise).
            model = AdvancedResUNet().to(DEVICE)
            # NOTE(review): torch.load unpickles arbitrary objects; fine for a
            # trusted local checkpoint, unsafe for untrusted files.
            state_dict = torch.load(CONFIG["model_path"], map_location=DEVICE)
            model.load_state_dict(state_dict)
            model.eval()
            _model = model
            print("[app] Model loaded successfully.")
        # Mel / inverse-mel transforms, shared by every request.
        n_stft = CONFIG["n_fft"] // 2 + 1
        _mel_scale = torchaudio.transforms.MelScale(
            n_mels=CONFIG["n_mels"],
            sample_rate=CONFIG["sample_rate"],
            n_stft=n_stft,
        ).to(DEVICE)
        _inv_mel = torchaudio.transforms.InverseMelScale(
            n_stft=n_stft,
            n_mels=CONFIG["n_mels"],
            sample_rate=CONFIG["sample_rate"],
        ).to(DEVICE)
        _window = torch.hann_window(CONFIG["n_fft"]).to(DEVICE)
    return _model, _mel_scale, _inv_mel, _window
def make_loud_and_clear(waveform: torch.Tensor) -> torch.Tensor:
    """Peak-normalize so the loudest sample sits at +/-1.0.

    Silent (all-zero) audio is returned unchanged to avoid dividing by zero.
    """
    peak = waveform.abs().max()
    if peak > 0:
        return waveform / peak
    return waveform
# ==========================================
# CLASSICAL SPECTRAL NOISE REDUCTION
# ==========================================
def apply_spectral_noise_gate(
    mag: torch.Tensor,
    sr: int,
    hop_length: int,
    strength: float = 2.0,
    noise_floor_mult: float = 0.1,
) -> torch.Tensor:
    """Spectral subtraction using a noise profile taken from the first 0.5 s.

    Args:
        mag: (B, n_freq, frames) magnitude spectrogram.
        sr: sample rate of the underlying audio.
        hop_length: STFT hop length (to convert seconds -> frames).
        strength: how many multiples of the noise profile to subtract.
        noise_floor_mult: fraction of the noise profile kept as a floor.

    Returns:
        (B, n_freq, frames) denoised magnitude spectrogram.
    """
    n_frames = mag.shape[-1]
    # Assume the opening 0.5 s is noise-only; average it per frequency bin.
    noise_frames = min(max(1, int(0.5 * sr / hop_length)), n_frames)
    noise_profile = mag[..., :noise_frames].mean(dim=-1, keepdim=True)  # (B, n_freq, 1)
    # Subtract a scaled copy of the profile, but keep a small residual floor
    # so bins never collapse to zero (which causes "musical noise" artifacts).
    gated = mag - strength * noise_profile
    return torch.maximum(gated, noise_floor_mult * noise_profile)
# ==========================================
# CORE DENOISING (phase-preserving, ISTFT)
# ==========================================
def denoise_waveform_tensor(waveform: torch.Tensor, sr: int):
    """Denoise a mono waveform with the U-Net mask (if loaded) + spectral gate.

    Args:
        waveform: (1, T) or (T,) CPU float tensor.
        sr: sample rate of the input waveform.

    Returns:
        Tuple of (denoised (1, T_out) CPU float32 tensor in [-1, 1],
        output sample rate — equal to the input ``sr``).
    """
    model, mel_scale, inv_mel, window = load_model_and_transforms()
    use_model = model is not None

    orig_sr = sr
    work_sr = CONFIG["sample_rate"]

    # 1) Resample to the working SR if needed. Done while the tensor is still
    #    on CPU, where the Resample module's kernel lives.
    if orig_sr != work_sr:
        resampler = torchaudio.transforms.Resample(orig_sr, work_sr)
        waveform = resampler(waveform)
        sr = work_sr
    else:
        sr = orig_sr

    # Ensure (1, T) float on the compute device.
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    waveform = waveform.to(DEVICE).float()

    # 2) STFT to get the complex spectrogram.
    stft_complex = torch.stft(
        waveform,
        n_fft=CONFIG["n_fft"],
        hop_length=CONFIG["hop_length"],
        window=window,
        return_complex=True,
    )  # (1, n_freq, frames)
    noisy_mag = stft_complex.abs()  # (1, n_freq, frames)
    noisy_phase = torch.angle(stft_complex)  # (1, n_freq, frames)

    # ===== A) U-NET MASK IN MEL SPACE (OPTIONAL) =====
    if use_model:
        noisy_mel = mel_scale(noisy_mag)  # (1, n_mels, frames)
        # NOTE: log1p(x + 1e-6) must match the transform used at training time.
        noisy_log_mel = torch.log1p(noisy_mel + 1e-6)
        model_in = noisy_log_mel.unsqueeze(1)  # (1, 1, n_mels, frames)
        with torch.no_grad():
            mel_mask = model(model_in)  # (1, 1, n_mels, frames), values in 0-1
        mel_mask = mel_mask.squeeze(1)  # (1, n_mels, frames)
        denoised_log_mel = noisy_log_mel * mel_mask
        denoised_mel = torch.clamp(torch.expm1(denoised_log_mel), min=0.0)
        mag_for_gate = inv_mel(denoised_mel)  # back to linear bins (1, n_freq, frames)
    else:
        mag_for_gate = noisy_mag

    # ===== B) CLASSICAL SPECTRAL NOISE GATE =====
    mag_denoised = apply_spectral_noise_gate(
        mag_for_gate,
        sr=sr,
        hop_length=CONFIG["hop_length"],
        strength=CONFIG["noise_reduction_strength"],
        noise_floor_mult=CONFIG["noise_floor_mult"],
    )

    # 6) Recombine with the ORIGINAL phase for natural-sounding speech.
    complex_pred = mag_denoised * torch.exp(1j * noisy_phase)

    # 7) ISTFT back to a waveform, trimmed/padded to the input length.
    recon = torch.istft(
        complex_pred,
        n_fft=CONFIG["n_fft"],
        hop_length=CONFIG["hop_length"],
        window=window,
        length=waveform.shape[-1],
    )  # (1, T) or (T,)
    if recon.dim() == 1:
        recon = recon.unsqueeze(0)

    # 8) BUGFIX: move to CPU *before* resampling back. The back-resampler's
    #    kernel buffer is created on CPU, so on a CUDA device the old order
    #    (resample first, .cpu() after) crashed with a device mismatch.
    recon = recon.cpu()
    if orig_sr != work_sr:
        resampler_back = torchaudio.transforms.Resample(work_sr, orig_sr)
        recon = resampler_back(recon)
        sr_out = orig_sr
    else:
        sr_out = work_sr

    recon = make_loud_and_clear(recon)
    return recon, sr_out
# ==========================================
# GRADIO INTERFACE
# ==========================================
def denoise_gradio(audio):
    """Gradio callback: (sample_rate, np.ndarray) in -> (sample_rate, np.ndarray) out.

    Returns None when no audio was provided.
    """
    if audio is None:
        return None
    sample_rate, data = audio  # data: (T,) mono or (T, C) multi-channel
    # Collapse multi-channel recordings to mono by averaging the channels.
    if data.ndim == 2:
        data = data.mean(axis=1)
    # (T,) -> (1, T) CPU float tensor for the denoiser.
    wav_tensor = torch.from_numpy(data.astype(np.float32)[None, :])
    denoised, out_sr = denoise_waveform_tensor(wav_tensor, sample_rate)
    # Hand back the rate the denoiser actually produced, not the input rate.
    return (out_sr, denoised.squeeze(0).numpy().astype(np.float32))
# Build the UI at import time; `demo` is module-level so the hosting
# environment can pick it up.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🎧 Speech Denoiser (U-Net + Spectral Noise Gate)
Upload or record a noisy speech file, and this model will try to remove background noise
while keeping the voice natural.
- Uses a learned U-Net mask in mel space (if a model checkpoint is found)
- Plus a classical STFT-based noise gate for stronger suppression
"""
    )
    # Input/output widgets side by side; both use the (sample_rate, ndarray)
    # "numpy" format that denoise_gradio expects and returns.
    with gr.Row():
        inp = gr.Audio(type="numpy", label="Upload or record noisy audio")
        out = gr.Audio(type="numpy", label="Denoised output")
    btn = gr.Button("Denoise")
    # Wire the button to the denoising callback.
    btn.click(fn=denoise_gradio, inputs=inp, outputs=out)

# This is what Hugging Face Spaces calls
if __name__ == "__main__":
    demo.launch()