# app.py — Hugging Face Space entry point: speech denoiser (U-Net + spectral gate)
| import os | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchaudio | |
| import gradio as gr | |
| import asyncio | |
| from asyncio import base_events | |
# ==========================================
# PATCH ASYNCIO EVENT LOOP DESTRUCTOR
# (Silence "Invalid file descriptor: -1" on shutdown)
# ==========================================
def _make_quiet_loop_del(original_del):
    """Return a ``__del__`` wrapper that delegates to *original_del* but never
    lets a destructor-time cleanup error propagate, e.g.:
        ValueError: Invalid file descriptor: -1
    raised while the event loop is torn down at interpreter shutdown."""

    def _quiet_del(self):
        try:
            original_del(self)
        except Exception:
            # Swallow any cleanup error — raising inside __del__ would only
            # print noise to stderr during shutdown.
            pass

    return _quiet_del


base_events.BaseEventLoop.__del__ = _make_quiet_loop_del(
    base_events.BaseEventLoop.__del__
)
# ==========================================
# CONFIG
# ==========================================
# Central hyper-parameter dictionary shared by the STFT pipeline and the model.
CONFIG = {
    "sample_rate": 16000,  # working sample rate for all DSP (Hz)
    "n_fft": 1024,  # STFT window size
    "hop_length": 256,  # STFT hop size (75% overlap with n_fft=1024)
    "n_mels": 80,  # mel bands fed to the U-Net
    "model_path": "best_denoiser_model.pth",  # checkpoint, loaded lazily
    # Classical noise-reduction parameters
    "noise_reduction_strength": 3.0,  # 1.0 = mild, 2–3 = stronger
    "noise_floor_mult": 0.2,  # how much noise to leave (0.1 ≈ -20 dB)
}

# Run on GPU when available; model, transforms, and tensors are moved here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[app] Using device: {DEVICE}")
# ==========================================
# MODEL (must match training)
# ==========================================
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with a skip connection.

    When ``in_channels != out_channels`` the skip path projects the input
    through a 1x1 conv + BatchNorm so it can be added to the main branch;
    otherwise the input passes through unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if in_channels == out_channels:
            # Identity skip: an empty Sequential simply forwards its input.
            self.shortcut = nn.Sequential()
        else:
            # Channel mismatch: 1x1 projection so the shapes add up.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        skip = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + skip)
class AdvancedResUNet(nn.Module):
    """Residual U-Net that predicts a 0–1 mask over a log-mel spectrogram.

    Input:  (B, 1, n_mels, frames) log-mel features.
    Output: (B, 1, n_mels, frames) sigmoid mask — the MASK itself, not the
    masked input; the caller multiplies it onto the features.
    """

    def __init__(self):
        super().__init__()
        # Encoder: three residual stages, each followed by a 2x2 max-pool.
        self.enc1 = ResidualBlock(1, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ResidualBlock(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ResidualBlock(64, 128)
        self.pool3 = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = ResidualBlock(128, 256)
        # Decoder: transpose-conv upsampling + skip concatenation.
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = ResidualBlock(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = ResidualBlock(128, 64)
        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1 = ResidualBlock(64, 32)
        self.final_conv = nn.Conv2d(32, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _fuse(upsampled, skip):
        """Resize *upsampled* to the skip tensor's spatial size (needed when
        odd input sizes make pool/upsample disagree) and concat on channels."""
        if upsampled.shape != skip.shape:
            upsampled = F.interpolate(upsampled, size=skip.shape[2:])
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        # Encoder path, keeping the pre-pool activations for the skips.
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        b = self.bottleneck(self.pool3(e3))
        # Decoder path: upsample, fuse with the matching encoder output.
        d3 = self.dec3(self._fuse(self.up3(b), e3))
        d2 = self.dec2(self._fuse(self.up2(d3), e2))
        d1 = self.dec1(self._fuse(self.up1(d2), e1))
        # IMPORTANT: return the MASK, not x * mask.
        return self.sigmoid(self.final_conv(d1))
# ==========================================
# GLOBALS: model + transforms
# ==========================================
# Lazily-initialized singletons, populated by load_model_and_transforms().
_model = None  # AdvancedResUNet instance, or None when no checkpoint is found
_mel_scale = None  # torchaudio MelScale (linear-frequency -> mel)
_inv_mel = None  # torchaudio InverseMelScale (mel -> linear-frequency)
_window = None  # Hann window tensor shared by STFT and ISTFT
def load_model_and_transforms():
    """Lazily build the denoiser model and mel transforms, exactly once.

    Returns:
        (model, mel_scale, inv_mel, window) — *model* is None when no
        checkpoint file exists; the pipeline then falls back to classical
        spectral gating only.

    Bug fix: the original gated on ``_model is None``, so when the checkpoint
    was missing the whole setup (warnings, MelScale/InverseMelScale, window)
    re-ran on EVERY call. A dedicated ``_initialized`` flag makes the init
    truly one-shot in both cases.
    """
    global _model, _mel_scale, _inv_mel, _window, _initialized
    # globals().get(): the flag may not exist yet on the very first call.
    if not globals().get("_initialized", False):
        print(f"[app] Loading model from {CONFIG['model_path']} on {DEVICE}...")
        if not os.path.exists(CONFIG["model_path"]):
            print(
                f"[app] WARNING: Model checkpoint not found at {CONFIG['model_path']}."
            )
            print("[app] Denoiser will use ONLY classical spectral gating.")
            _model = None
        else:
            # Only construct the network when there is a checkpoint to load.
            model = AdvancedResUNet().to(DEVICE)
            # NOTE(review): torch.load on an untrusted checkpoint can execute
            # arbitrary code; consider weights_only=True on torch >= 1.13.
            state_dict = torch.load(CONFIG["model_path"], map_location=DEVICE)
            model.load_state_dict(state_dict)
            model.eval()
            _model = model
            print("[app] Model loaded successfully.")
        # Mel / inverse-mel transforms — built once, reused by every request.
        n_stft = CONFIG["n_fft"] // 2 + 1
        _mel_scale = torchaudio.transforms.MelScale(
            n_mels=CONFIG["n_mels"],
            sample_rate=CONFIG["sample_rate"],
            n_stft=n_stft,
        ).to(DEVICE)
        _inv_mel = torchaudio.transforms.InverseMelScale(
            n_stft=n_stft,
            n_mels=CONFIG["n_mels"],
            sample_rate=CONFIG["sample_rate"],
        ).to(DEVICE)
        _window = torch.hann_window(CONFIG["n_fft"]).to(DEVICE)
        _initialized = True
    return _model, _mel_scale, _inv_mel, _window
def make_loud_and_clear(waveform: torch.Tensor) -> torch.Tensor:
    """Peak-normalize so the largest-magnitude sample becomes 1.0.

    Quiet recordings are boosted and loud ones scaled down, so the output
    never clips. All-zero (silent) input is returned unchanged to avoid a
    division by zero.
    """
    peak = waveform.abs().max()
    return waveform / peak if peak > 0 else waveform
# ==========================================
# CLASSICAL SPECTRAL NOISE REDUCTION
# ==========================================
def apply_spectral_noise_gate(
    mag: torch.Tensor,
    sr: int,
    hop_length: int,
    strength: float = 2.0,
    noise_floor_mult: float = 0.1,
) -> torch.Tensor:
    """
    Simple spectral subtraction / gating using a noise estimate from the first 0.5s.
    mag: (B, n_freq, frames)
    Returns: (B, n_freq, frames) denoised magnitude
    """
    n_frames = mag.shape[-1]
    # Treat the first half second as "noise only"; clamp to the clip length
    # so very short clips still contribute at least one estimation frame.
    noise_frames = min(max(1, int(0.5 * sr / hop_length)), n_frames)
    # Per-bin mean magnitude over the noise-only region: (B, n_freq, 1)
    noise_profile = mag[..., :noise_frames].mean(dim=-1, keepdim=True)
    # Spectral subtraction, floored at a fraction of the noise profile so
    # no bin is fully zeroed (which would cause "musical noise" artifacts).
    return torch.maximum(
        mag - strength * noise_profile,
        noise_floor_mult * noise_profile,
    )
# ==========================================
# CORE DENOISING (phase-preserving, ISTFT)
# ==========================================
def denoise_waveform_tensor(waveform: torch.Tensor, sr: int):
    """Denoise a mono waveform with the optional U-Net mel mask plus a
    classical spectral noise gate, reusing the noisy phase for naturalness.

    waveform: (1, T) on CPU
    sr: original sample rate of input
    Returns: (waveform_out (1, T_out) CPU float32 in [-1, 1], sr_out)
    """
    model, mel_scale, inv_mel, window = load_model_and_transforms()
    use_model = model is not None
    orig_sr = sr
    work_sr = CONFIG["sample_rate"]
    # 1) Resample to the working SR if needed. The input is still on CPU
    #    here, matching the freshly created (CPU-resident) Resample module.
    if orig_sr != work_sr:
        waveform = torchaudio.transforms.Resample(orig_sr, work_sr)(waveform)
    sr = work_sr  # after optional resampling we always work at work_sr
    # Ensure (1, T) float on the compute device
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    waveform = waveform.to(DEVICE).float()
    # 2) STFT to a complex spectrogram: (1, n_freq, frames)
    stft_complex = torch.stft(
        waveform,
        n_fft=CONFIG["n_fft"],
        hop_length=CONFIG["hop_length"],
        window=window,
        return_complex=True,
    )
    noisy_mag = stft_complex.abs()  # (1, n_freq, frames)
    noisy_phase = torch.angle(stft_complex)  # (1, n_freq, frames)
    # ===== A) U-NET MASK IN MEL SPACE (OPTIONAL) =====
    if use_model:
        noisy_mel = mel_scale(noisy_mag)  # (1, n_mels, frames)
        noisy_log_mel = torch.log1p(noisy_mel + 1e-6)
        model_in = noisy_log_mel.unsqueeze(1)  # (1, 1, n_mels, frames)
        with torch.no_grad():
            mel_mask = model(model_in)  # (1, 1, n_mels, frames)
        mel_mask = mel_mask.squeeze(1)  # (1, n_mels, frames)
        denoised_log_mel = noisy_log_mel * mel_mask
        # expm1 can dip slightly below zero numerically; clamp before the
        # (non-negative least-squares) inverse-mel projection.
        denoised_mel = torch.clamp(torch.expm1(denoised_log_mel), min=0.0)
        mag_for_gate = inv_mel(denoised_mel)  # (1, n_freq, frames)
    else:
        mag_for_gate = noisy_mag
    # ===== B) CLASSICAL SPECTRAL NOISE GATE =====
    mag_denoised = apply_spectral_noise_gate(
        mag_for_gate,
        sr=sr,
        hop_length=CONFIG["hop_length"],
        strength=CONFIG["noise_reduction_strength"],
        noise_floor_mult=CONFIG["noise_floor_mult"],
    )
    # 3) Recombine with the ORIGINAL phase (for natural speech)
    complex_pred = mag_denoised * torch.exp(1j * noisy_phase)
    # 4) ISTFT back to the time domain, trimmed to the input length
    recon = torch.istft(
        complex_pred,
        n_fft=CONFIG["n_fft"],
        hop_length=CONFIG["hop_length"],
        window=window,
        length=waveform.shape[-1],
    )  # (1, T) or (T,)
    if recon.dim() == 1:
        recon = recon.unsqueeze(0)
    # 5) BUG FIX: move to CPU *before* resampling back. The Resample module
    #    below is created on CPU, so feeding it a CUDA tensor (the original
    #    code path when DEVICE == "cuda") raised a device-mismatch error.
    recon = recon.cpu()
    if orig_sr != work_sr:
        recon = torchaudio.transforms.Resample(work_sr, orig_sr)(recon)
        sr_out = orig_sr
    else:
        sr_out = work_sr
    return make_loud_and_clear(recon), sr_out
# ==========================================
# GRADIO INTERFACE
# ==========================================
def denoise_gradio(audio):
    """
    Gradio passes: (sample_rate, np.ndarray)
    We return the same format: (sample_rate, np.ndarray)

    Bug fix: Gradio frequently delivers integer PCM (e.g. int16 from the
    microphone widget). The original cast with astype(float32) only, leaving
    samples in the tens of thousands; the model's log1p mel front end is not
    amplitude-invariant, so integer input is now scaled to [-1, 1] first.
    """
    if audio is None:
        return None
    sample_rate, data = audio  # data: (T,) or (T, C)
    if np.issubdtype(data.dtype, np.integer):
        # Scale integer PCM by the dtype's full-scale value.
        data = data.astype(np.float32) / float(np.iinfo(data.dtype).max)
    if data.ndim == 2:
        # convert to mono: average channels
        data = data.mean(axis=1)
    # (T,) -> (1, T)
    wav_np = data.astype(np.float32)[None, :]  # (1, T)
    wav_tensor = torch.from_numpy(wav_np)  # CPU tensor
    denoised_tensor, out_sr = denoise_waveform_tensor(wav_tensor, sample_rate)
    denoised_np = denoised_tensor.squeeze(0).numpy().astype(np.float32)
    # IMPORTANT: return out_sr, not the original sample_rate blindly
    return (out_sr, denoised_np)
# Gradio UI: one audio-in / audio-out row plus a trigger button.
with gr.Blocks() as demo:
    # User-facing description rendered above the widgets.
    gr.Markdown(
        """
# 🎧 Speech Denoiser (U-Net + Spectral Noise Gate)
Upload or record a noisy speech file, and this model will try to remove background noise
while keeping the voice natural.
- Uses a learned U-Net mask in mel space (if a model checkpoint is found)
- Plus a classical STFT-based noise gate for stronger suppression
"""
    )
    with gr.Row():
        # Both widgets use the (sample_rate, np.ndarray) "numpy" format,
        # matching denoise_gradio's input/output contract.
        inp = gr.Audio(type="numpy", label="Upload or record noisy audio")
        out = gr.Audio(type="numpy", label="Denoised output")
    btn = gr.Button("Denoise")
    btn.click(fn=denoise_gradio, inputs=inp, outputs=out)

# This is what Hugging Face Spaces calls
if __name__ == "__main__":
    demo.launch()