"""Speech denoiser app.

Combines a learned U-Net mask in log-mel space (optional, if a checkpoint is
present) with a classical STFT spectral noise gate, and serves the result
through a Gradio interface.
"""

import os

import asyncio
from asyncio import base_events

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import gradio as gr

# ==========================================
# PATCH ASYNCIO EVENT LOOP DESTRUCTOR
# (Silence "Invalid file descriptor: -1" on shutdown)
# ==========================================
_original_del = base_events.BaseEventLoop.__del__


def _safe_event_loop_del(self, _orig=_original_del):
    """Wrap the original BaseEventLoop destructor and swallow cleanup errors."""
    try:
        _orig(self)
    except Exception:
        # Swallow any cleanup errors like:
        #   ValueError: Invalid file descriptor: -1
        pass


base_events.BaseEventLoop.__del__ = _safe_event_loop_del

# ==========================================
# CONFIG
# ==========================================
CONFIG = {
    "sample_rate": 16000,
    "n_fft": 1024,
    "hop_length": 256,
    "n_mels": 80,
    "model_path": "best_denoiser_model.pth",
    # Classical noise-reduction parameters
    "noise_reduction_strength": 3.0,  # 1.0 = mild, 2-3 = stronger
    "noise_floor_mult": 0.2,  # how much noise to leave (0.1 ~= -20 dB)
}

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[app] Using device: {DEVICE}")


# ==========================================
# MODEL (must match training)
# ==========================================
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BN layers with a (possibly projected) residual shortcut."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 projection only when the channel count changes; identity otherwise.
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return F.relu(out)


class AdvancedResUNet(nn.Module):
    """Residual U-Net that predicts a (0, 1) mask over a log-mel spectrogram.

    Input:  (B, 1, n_mels, frames) log-mel features.
    Output: (B, 1, n_mels, frames) sigmoid mask — the MASK itself, not the
    masked input (masking happens in the caller).
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.enc1 = ResidualBlock(1, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ResidualBlock(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ResidualBlock(64, 128)
        self.pool3 = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = ResidualBlock(128, 256)
        # Decoder (skip connections double the channel count into each dec block)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = ResidualBlock(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = ResidualBlock(128, 64)
        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1 = ResidualBlock(64, 32)
        self.final_conv = nn.Conv2d(32, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Model outputs a MASK in log-mel space (0-1)
        e1 = self.enc1(x)
        p1 = self.pool1(e1)
        e2 = self.enc2(p1)
        p2 = self.pool2(e2)
        e3 = self.enc3(p2)
        p3 = self.pool3(e3)

        b = self.bottleneck(p3)

        # Odd input sizes make transposed-conv outputs mismatch the skip
        # tensors by a pixel; interpolate to realign before concatenation.
        d3 = self.up3(b)
        if d3.shape != e3.shape:
            d3 = F.interpolate(d3, size=e3.shape[2:])
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)

        d2 = self.up2(d3)
        if d2.shape != e2.shape:
            d2 = F.interpolate(d2, size=e2.shape[2:])
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)

        d1 = self.up1(d2)
        if d1.shape != e1.shape:
            d1 = F.interpolate(d1, size=e1.shape[2:])
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        mask = self.sigmoid(self.final_conv(d1))  # (B, 1, n_mels, frames)
        # IMPORTANT: return MASK, not x * mask
        return mask


# ==========================================
# GLOBALS: model + transforms
# ==========================================
_model = None
_mel_scale = None
_inv_mel = None
_window = None
# Latch so a missing checkpoint is reported once and the transforms are built
# exactly once (previously `_model is None` caused a full re-attempt per call).
_load_attempted = False


def load_model_and_transforms():
    """Lazily load the denoiser model and shared audio transforms.

    Returns:
        tuple: (model or None, MelScale, InverseMelScale, Hann window tensor).
        `model` is None when no checkpoint exists; callers then fall back to
        classical spectral gating only.
    """
    global _model, _mel_scale, _inv_mel, _window, _load_attempted
    if not _load_attempted:
        _load_attempted = True
        print(f"[app] Loading model from {CONFIG['model_path']} on {DEVICE}...")
        model = AdvancedResUNet().to(DEVICE)
        if not os.path.exists(CONFIG["model_path"]):
            print(
                f"[app] WARNING: Model checkpoint not found at {CONFIG['model_path']}."
            )
            print("[app] Denoiser will use ONLY classical spectral gating.")
            _model = None
        else:
            # NOTE(review): consider torch.load(..., weights_only=True) on
            # recent PyTorch versions to avoid unpickling arbitrary objects.
            state_dict = torch.load(CONFIG["model_path"], map_location=DEVICE)
            model.load_state_dict(state_dict)
            model.eval()
            _model = model
            print("[app] Model loaded successfully.")

        # Mel / inverse-mel transforms, shared by every request.
        n_stft = CONFIG["n_fft"] // 2 + 1
        _mel_scale = torchaudio.transforms.MelScale(
            n_mels=CONFIG["n_mels"],
            sample_rate=CONFIG["sample_rate"],
            n_stft=n_stft,
        ).to(DEVICE)
        _inv_mel = torchaudio.transforms.InverseMelScale(
            n_stft=n_stft,
            n_mels=CONFIG["n_mels"],
            sample_rate=CONFIG["sample_rate"],
        ).to(DEVICE)
        _window = torch.hann_window(CONFIG["n_fft"]).to(DEVICE)
    return _model, _mel_scale, _inv_mel, _window


def make_loud_and_clear(waveform: torch.Tensor) -> torch.Tensor:
    """Normalize amplitude to avoid clipping / very low volume."""
    max_val = waveform.abs().max()
    if max_val > 0:
        waveform = waveform / max_val
    return waveform


# ==========================================
# CLASSICAL SPECTRAL NOISE REDUCTION
# ==========================================
def apply_spectral_noise_gate(
    mag: torch.Tensor,
    sr: int,
    hop_length: int,
    strength: float = 2.0,
    noise_floor_mult: float = 0.1,
) -> torch.Tensor:
    """Simple spectral subtraction / gating.

    The noise profile is estimated from the first 0.5 s of the clip, which is
    assumed to be (mostly) noise-only.

    Args:
        mag: (B, n_freq, frames) STFT magnitudes.
        sr: sample rate of the signal `mag` was computed from.
        hop_length: STFT hop length used to compute `mag`.
        strength: how many multiples of the noise profile to subtract.
        noise_floor_mult: fraction of the noise profile kept as a floor,
            which limits "musical noise" artifacts.

    Returns:
        (B, n_freq, frames) denoised magnitudes.
    """
    B, n_freq, n_frames = mag.shape

    # Use first 0.5 seconds as "noise only" region
    est_frames = max(1, int(0.5 * sr / hop_length))
    est_frames = min(est_frames, n_frames)
    noise_profile = mag[:, :, :est_frames].mean(dim=-1, keepdim=True)  # (B, n_freq, 1)

    # Subtract scaled noise profile
    denoised = mag - strength * noise_profile

    # Never go below some fraction of noise_profile (avoid musical noise)
    floor = noise_floor_mult * noise_profile
    denoised = torch.maximum(denoised, floor)
    return denoised


# ==========================================
# CORE DENOISING (phase-preserving, ISTFT)
# ==========================================
def denoise_waveform_tensor(waveform: torch.Tensor, sr: int):
    """Denoise a mono waveform, preserving the noisy phase.

    Args:
        waveform: (1, T) float tensor on CPU, roughly in [-1, 1].
        sr: original sample rate of the input.

    Returns:
        tuple: (denoised waveform (1, T_out), CPU float32 in [-1, 1]; output
        sample rate — always equal to the input `sr`).
    """
    model, mel_scale, inv_mel, window = load_model_and_transforms()
    use_model = model is not None

    orig_sr = sr
    work_sr = CONFIG["sample_rate"]

    # 1) Resample to working SR if needed (input is still on CPU here,
    #    matching the CPU-resident Resample kernel).
    if orig_sr != work_sr:
        resampler = torchaudio.transforms.Resample(orig_sr, work_sr)
        waveform = resampler(waveform)
        sr = work_sr
    else:
        sr = orig_sr

    # Ensure (1, T), float
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    waveform = waveform.to(DEVICE).float()

    # 2) STFT to get complex representation
    stft_complex = torch.stft(
        waveform,
        n_fft=CONFIG["n_fft"],
        hop_length=CONFIG["hop_length"],
        window=window,
        return_complex=True,
    )  # (1, n_freq, frames)
    noisy_mag = stft_complex.abs()  # (1, n_freq, frames)
    noisy_phase = torch.angle(stft_complex)  # (1, n_freq, frames)

    # ===== A) U-NET MASK IN MEL SPACE (OPTIONAL) =====
    if use_model:
        noisy_mel = mel_scale(noisy_mag)  # (1, n_mels, frames)
        noisy_log_mel = torch.log1p(noisy_mel + 1e-6)  # (1, n_mels, frames)
        model_in = noisy_log_mel.unsqueeze(1)  # (1, 1, n_mels, frames)
        with torch.no_grad():
            mel_mask = model(model_in)  # (1, 1, n_mels, frames)
        mel_mask = mel_mask.squeeze(1)  # (1, n_mels, frames)
        denoised_log_mel = noisy_log_mel * mel_mask
        denoised_mel = torch.expm1(denoised_log_mel)
        denoised_mel = torch.clamp(denoised_mel, min=0.0)
        # NOTE: InverseMelScale solves an iterative least-squares problem and
        # can be slow on long clips.
        mag_for_gate = inv_mel(denoised_mel)  # (1, n_freq, frames)
    else:
        mag_for_gate = noisy_mag

    # ===== B) CLASSICAL SPECTRAL NOISE GATE =====
    mag_denoised = apply_spectral_noise_gate(
        mag_for_gate,
        sr=sr,
        hop_length=CONFIG["hop_length"],
        strength=CONFIG["noise_reduction_strength"],
        noise_floor_mult=CONFIG["noise_floor_mult"],
    )

    # 6) Use original phase (for natural speech)
    complex_pred = mag_denoised * torch.exp(1j * noisy_phase)

    # 7) ISTFT back to waveform
    recon = torch.istft(
        complex_pred,
        n_fft=CONFIG["n_fft"],
        hop_length=CONFIG["hop_length"],
        window=window,
        length=waveform.shape[-1],
    )  # (1, T) or (T,)
    if recon.dim() == 1:
        recon = recon.unsqueeze(0)

    # Move to CPU BEFORE the back-resample: torchaudio's Resample transform
    # keeps its kernel on CPU, so applying it to a CUDA tensor would fail.
    recon = recon.cpu()

    # 8) Resample BACK to original sample rate for playback
    if orig_sr != work_sr:
        resampler_back = torchaudio.transforms.Resample(work_sr, orig_sr)
        recon = resampler_back(recon)
        sr_out = orig_sr
    else:
        sr_out = work_sr

    recon = make_loud_and_clear(recon)
    return recon, sr_out


# ==========================================
# GRADIO INTERFACE
# ==========================================
def denoise_gradio(audio):
    """Gradio entry point.

    Gradio passes: (sample_rate, np.ndarray)
    We return the same format: (sample_rate, np.ndarray)
    """
    if audio is None:
        return None

    sample_rate, data = audio  # data: (T,) or (T, C)
    src_dtype = data.dtype

    if data.ndim == 2:
        # convert to mono: average channels
        data = data.mean(axis=1)

    # Gradio commonly delivers integer PCM (e.g. int16). Scale to [-1, 1] so
    # the log-mel features match the model's training distribution; float
    # input is assumed to already be normalized.
    if np.issubdtype(src_dtype, np.integer):
        data = data.astype(np.float32) / float(np.iinfo(src_dtype).max)
    else:
        data = data.astype(np.float32)

    # (T,) -> (1, T)
    wav_np = data[None, :]  # (1, T)
    wav_tensor = torch.from_numpy(wav_np)  # CPU tensor

    denoised_tensor, out_sr = denoise_waveform_tensor(wav_tensor, sample_rate)
    denoised_np = denoised_tensor.squeeze(0).numpy().astype(np.float32)

    # IMPORTANT: return out_sr, not the original sample_rate blindly
    return (out_sr, denoised_np)


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 🎧 Speech Denoiser (U-Net + Spectral Noise Gate)

        Upload or record a noisy speech file, and this model will try to remove background noise while keeping the voice natural.

        - Uses a learned U-Net mask in mel space (if a model checkpoint is found)
        - Plus a classical STFT-based noise gate for stronger suppression
        """
    )
    with gr.Row():
        inp = gr.Audio(type="numpy", label="Upload or record noisy audio")
        out = gr.Audio(type="numpy", label="Denoised output")
    btn = gr.Button("Denoise")
    btn.click(fn=denoise_gradio, inputs=inp, outputs=out)

# This is what Hugging Face Spaces calls
if __name__ == "__main__":
    demo.launch()