Spaces:
Sleeping
Sleeping
| import glob | |
| import math | |
| import os | |
| import tempfile | |
| import time | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple, Union | |
| import subprocess | |
| from dataclasses import dataclass | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| from loguru import logger | |
| from PIL import Image as PILImage | |
| from torch import Tensor | |
| from torchaudio.backend.common import AudioMetaData | |
| from df import config | |
| from df.enhance import enhance, init_df, load_audio, save_audio | |
| from df.io import resample | |
| # ============================================================================ | |
| # Configuration and Setup | |
| # ============================================================================ | |
@dataclass
class AppConfig:
    """Application configuration.

    Declared as a dataclass so that the annotated names below become
    constructor fields: the module builds its configuration with
    ``AppConfig(device=...)``, which a plain class with bare annotations
    would reject with ``TypeError``.
    """

    # Torch device the model is moved to. Required; has no default.
    device: torch.device
    # Sample rate in Hz used when loading and processing audio.
    sample_rate: int = 48000
    # Longest input, in seconds, that is processed; longer inputs are cut.
    max_duration_seconds: int = 3600
    # Age in hours after which temporary files are removed.
    cleanup_hours: int = 2
    # Directory that holds temporary files.
    temp_dir: str = "/tmp"
    # Location handed to ``init_df`` when loading the model.
    model_path: str = "./DeepFilterNet2"
    # Length in seconds of the fade-in applied to enhanced audio.
    fade_duration: float = 0.15
class AudioProcessor:
    """Handles audio processing operations: noise mixing and enhancement."""

    def __init__(self, model, df, config: "AppConfig"):
        """Store the collaborators used by the processing methods.

        Args:
            model: Network passed as the first argument to ``enhance``.
            df: State object passed as the second argument to ``enhance``.
            config: Application settings; ``sample_rate`` and
                ``fade_duration`` are read when fading enhanced audio.
        """
        self.model = model
        self.df = df
        self.config = config

    @staticmethod
    def _as_mono(signal) -> Tensor:
        """Return ``signal`` as a ``[1, T]`` tensor averaged over channels."""
        signal = torch.as_tensor(signal)
        if signal.dim() == 1:
            # Without this, mean(0) would collapse a 1D signal to one scalar.
            signal = signal.unsqueeze(0)
        return signal.mean(0, keepdim=True)

    def mix_at_snr(
        self, clean: Tensor, noise: Tensor, snr: float, eps: float = 1e-10
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Mix clean and noise signal at a given SNR.

        Both inputs are averaged over their first dimension. Noise shorter
        than the clean signal is tiled; longer noise is cropped at a random
        offset so that both have the same length.

        Args:
            clean: Clean signal, either 1D ``[T]`` or 2D ``[C, T]``.
            noise: Noise signal, either 1D ``[T]`` or 2D ``[C, T]``.
            snr: Signal to noise ratio in dB.
            eps: Small epsilon for numerical stability.

        Returns:
            Tuple ``(clean, noise, mixture)`` of ``[1, T]`` tensors, where
            ``mixture = clean + noise``. When the mixture peak exceeds 1,
            all three are divided by that peak.

        Raises:
            ValueError: If the noise signal is empty or the mixture contains
                non-finite values.
        """
        clean = self._as_mono(clean)
        noise = self._as_mono(noise)
        n_clean = clean.shape[1]
        n_noise = noise.shape[1]
        if n_noise == 0:
            raise ValueError("Noise signal is empty")

        # Repeat noise if shorter than clean signal
        if n_noise < n_clean:
            noise = noise.repeat((1, math.ceil(n_clean / n_noise)))

        # Random starting point for noise
        max_start = noise.shape[1] - n_clean
        start = int(torch.randint(0, max_start, ()).item()) if max_start > 0 else 0
        noise = noise[:, start : start + n_clean]

        # Scale the noise so that the energy ratio matches the requested SNR
        e_speech = torch.mean(clean.pow(2)) + eps
        e_noise = torch.mean(noise.pow(2)) + eps
        k = torch.sqrt((e_noise / e_speech) * 10 ** (snr / 10) + eps)
        noise = noise / k
        mixture = clean + noise

        # Explicit raise instead of ``assert`` so the check survives ``-O``
        if not torch.isfinite(mixture).all():
            raise ValueError("Non-finite values detected in mixture")

        # Check for clipping
        max_m = mixture.abs().max()
        if max_m > 1:
            logger.warning(f"Clipping detected during mixing. Reducing gain by {1/max_m:.3f}")
            clean, noise, mixture = clean / max_m, noise / max_m, mixture / max_m
        return clean, noise, mixture

    def _apply_fade_in(self, audio: Tensor) -> Tensor:
        """Multiply the start of ``audio`` by a linear 0-to-1 ramp.

        The ramp spans ``sample_rate * fade_duration`` samples, shortened to
        the clip length when the clip is shorter than that, so short clips
        no longer fail with a negative tensor size.
        """
        n_samples = audio.shape[-1]
        fade_samples = int(self.config.sample_rate * self.config.fade_duration)
        fade_samples = max(0, min(fade_samples, n_samples))
        gain = torch.ones(1, n_samples, dtype=audio.dtype, device=audio.device)
        gain[:, :fade_samples] = torch.linspace(
            0.0, 1.0, fade_samples, dtype=audio.dtype, device=audio.device
        )
        return audio * gain

    def enhance_audio(self, audio: Tensor) -> Tensor:
        """Enhance audio using the DeepFilterNet model.

        Args:
            audio: Input audio tensor.

        Returns:
            Enhanced audio tensor with a fade-in applied to avoid clicks.
        """
        logger.info(f"Enhancing audio with shape {audio.shape}")
        with torch.no_grad():
            enhanced = enhance(self.model, self.df, audio)
        return self._apply_fade_in(enhanced)
class AudioLoader:
    """Handles audio loading from various sources.

    Both methods are static: neither takes ``self`` and they are called on
    the class as well as on an instance elsewhere in this file.
    """

    # Extensions that are transcoded to WAV with ffmpeg before loading.
    _CONVERT_EXTENSIONS = (".mp3", ".m4a", ".ogg", ".flac", ".aac")

    @staticmethod
    def ensure_wav(filepath: str) -> str:
        """Convert audio files to WAV using ffmpeg if needed.

        Args:
            filepath: Path to input audio file. Empty or ``None`` values
                are returned unchanged.

        Returns:
            Path to the WAV file. This is ``filepath`` itself unless its
            extension is one of the transcoded formats, in which case it is
            the same path with a ``.wav`` suffix.

        Raises:
            subprocess.CalledProcessError: If ffmpeg exits with an error.
        """
        if not filepath:
            return filepath
        source = Path(filepath)
        file_ext = source.suffix.lower()
        if file_ext not in AudioLoader._CONVERT_EXTENSIONS:
            return filepath

        wav_path = str(source.with_suffix(".wav"))
        try:
            subprocess.run(
                ["ffmpeg", "-y", "-i", filepath, "-acodec", "pcm_s16le", wav_path],
                check=True,
                capture_output=True,
            )
        except subprocess.CalledProcessError as e:
            # stderr is captured as bytes; decode it so the log is readable.
            stderr = e.stderr.decode(errors="replace") if isinstance(e.stderr, bytes) else e.stderr
            logger.error(f"FFmpeg conversion failed: {stderr}")
            raise
        logger.info(f"Converted {file_ext} to WAV: {wav_path}")
        return wav_path

    @staticmethod
    def load_audio_gradio(
        audio_or_file: Union[None, str, Tuple[int, np.ndarray]],
        sr: int,
    ) -> "Optional[Tuple[Tensor, AudioMetaData]]":
        """Load audio from Gradio input (file path or recorded audio).

        Args:
            audio_or_file: Either a file path string or a
                ``(sample_rate, audio_array)`` pair. ``None`` and the
                string ``"none"`` (any case) mean "no input".
            sr: Target sample rate.

        Returns:
            Tuple of (audio tensor, metadata), or ``None`` when there is no
            input.

        Raises:
            TypeError: If the input is neither a string nor a tuple/list.
        """
        if audio_or_file is None:
            return None

        if isinstance(audio_or_file, str):
            if audio_or_file.lower() == "none":
                return None
            # Load from file
            wav_path = AudioLoader.ensure_wav(audio_or_file)
            audio, meta = load_audio(wav_path, sr)
            return audio, meta

        # Load from Gradio recording
        if not isinstance(audio_or_file, (tuple, list)):
            raise TypeError(
                f"Expected a file path or (sample_rate, array) pair, got {type(audio_or_file).__name__}"
            )
        meta = AudioMetaData(-1, -1, -1, -1, "")
        meta.sample_rate, audio_np = audio_or_file

        # [T] or [T, C] -> [C, T]
        # NOTE(review): assumes the array is samples-first - confirm with Gradio.
        audio_np = audio_np.reshape(audio_np.shape[0], -1).T

        # Convert to float32; integer PCM is scaled by its full-scale value
        if audio_np.dtype == np.int16:
            audio_np = (audio_np / (1 << 15)).astype(np.float32)
        elif audio_np.dtype == np.int32:
            audio_np = (audio_np / (1 << 31)).astype(np.float32)
        elif audio_np.dtype != np.float32 and np.issubdtype(audio_np.dtype, np.floating):
            audio_np = audio_np.astype(np.float32)

        audio = resample(torch.from_numpy(audio_np), meta.sample_rate, sr)
        return audio, meta
class SpectrogramVisualizer:
    """Handles spectrogram visualization.

    Two Matplotlib figures (one for the noisy signal, one for the enhanced
    signal) are created once and redrawn on every call.
    """

    def __init__(self, figsize: Tuple[float, float] = (15.2, 4)):
        """Create the two figures and their axes.

        Args:
            figsize: Width and height passed to ``plt.subplots``.
        """
        self.figsize = figsize
        self.fig_noisy, self.ax_noisy = plt.subplots(figsize=figsize)
        self.fig_noisy.set_tight_layout(True)
        self.fig_enh, self.ax_enh = plt.subplots(figsize=figsize)
        self.fig_enh.set_tight_layout(True)

    def specshow(
        self,
        spec: Union[Tensor, np.ndarray],
        ax: Optional[plt.Axes] = None,
        title: Optional[str] = None,
        xlabel: Optional[str] = None,
        ylabel: Optional[str] = None,
        sr: int = 48000,
        n_fft: Optional[int] = None,
        hop: Optional[int] = None,
        vmin: float = -100,
        vmax: float = 0,
        cmap: str = "inferno",
    ):
        """Plot a spectrogram of shape [F, T].

        Args:
            spec: Spectrogram values, frequency bins first.
            ax: Axes to draw on; the current pyplot axes when omitted.
            title: Optional axes title.
            xlabel: Optional x-axis label.
            ylabel: Optional y-axis label.
            sr: Sample rate used to scale both axes.
            n_fft: FFT size; derived from the number of bins when omitted.
            hop: Hop size in samples; ``n_fft // 4`` when omitted.
            vmin: Lower colour limit.
            vmax: Upper colour limit.
            cmap: Matplotlib colormap name.

        Returns:
            The object returned by ``ax.pcolormesh``.
        """
        if ax is None:
            # The default used to be dereferenced directly and crashed.
            ax = plt.gca()
        spec_np = spec.cpu().numpy() if isinstance(spec, torch.Tensor) else spec
        n_bins = spec_np.shape[0]
        if n_fft is None:
            n_fft = n_bins * 2 if n_bins % 2 == 0 else (n_bins - 1) * 2
        hop = hop or n_fft // 4

        # Time axis in seconds, frequency axis in kHz
        t = np.arange(0, spec_np.shape[-1]) * hop / sr
        f = np.arange(0, n_bins) * sr // 2 / (n_fft // 2) / 1000
        im = ax.pcolormesh(
            t,
            f,
            spec_np,
            rasterized=True,
            shading="auto",
            vmin=vmin,
            vmax=vmax,
            cmap=cmap,
        )
        if title:
            ax.set_title(title)
        if xlabel:
            ax.set_xlabel(xlabel)
        if ylabel:
            ax.set_ylabel(ylabel)
        return im

    def create_spectrogram(
        self,
        audio: Tensor,
        figure: plt.Figure,
        ax: plt.Axes,
        sr: int = 48000,
        n_fft: int = 1024,
        hop: int = 512,
        title: Optional[str] = None,
    ) -> PILImage.Image:
        """Create spectrogram image from audio tensor.

        Args:
            audio: Audio samples; converted with ``torch.as_tensor``.
            figure: Figure that owns ``ax`` and is rendered to the image.
            ax: Axes that are cleared and redrawn.
            sr: Sample rate of ``audio``.
            n_fft: FFT and Hann window size.
            hop: Hop size in samples.
            title: Optional plot title.

        Returns:
            RGB image of the rendered figure.
        """
        audio = torch.as_tensor(audio)

        # Compute STFT. ``return_complex=True`` replaces the deprecated
        # real-valued output that needed ``view_as_complex`` afterwards.
        window = torch.hann_window(n_fft, device=audio.device)
        spec = torch.stft(audio, n_fft, hop, window=window, return_complex=True)
        magnitude = spec.abs() / window.pow(2).sum()
        spec = magnitude.clamp_min(1e-12).log10().mul(10)

        # Never let the colour scale top out below 0
        vmax = max(0.0, spec.max().item())
        if spec.dim() > 2:
            spec = spec.squeeze(0)

        ax.clear()
        self.specshow(
            spec,
            ax=ax,
            title=title,
            xlabel="Time [s]",
            ylabel="Frequency [kHz]",
            sr=sr,
            n_fft=n_fft,
            hop=hop,
            vmax=vmax,
        )
        figure.canvas.draw()

        # ``tostring_rgb`` is deprecated in recent Matplotlib; read the RGBA
        # buffer instead and drop the alpha channel.
        rgba = np.asarray(figure.canvas.buffer_rgba())
        return PILImage.fromarray(rgba).convert("RGB")
class FileManager:
    """Manages temporary file cleanup."""

    @staticmethod
    def cleanup_tmp(
        filter_list: Optional[List[str]] = None,
        hours_keep: int = 2,
        temp_dir: str = "/tmp",
    ):
        """Clean up old temporary files.

        Every regular file directly inside ``temp_dir`` whose modification
        time is older than ``hours_keep`` hours is removed, unless its path
        contains one of the keep patterns. Paths containing ``"p232"`` are
        always kept. Directories are skipped. Failures to remove a file are
        logged and do not stop the sweep.

        Args:
            filter_list: Substrings of paths to keep. ``None`` entries are
                ignored. The list is not modified.
            hours_keep: Number of hours to keep files.
            temp_dir: Temporary directory path.
        """
        # Build a private list so the caller's argument is never mutated.
        keep_patterns = [str(filt) for filt in (filter_list or []) if filt is not None]
        keep_patterns.append("p232")

        if not os.path.exists(temp_dir):
            return

        logger.info(f"Cleaning up temporary files older than {hours_keep} hours")
        now = time.time()
        cleaned = 0
        for filepath in glob.glob(os.path.join(temp_dir, "*")):
            try:
                # os.remove cannot delete directories; skip them quietly.
                if not os.path.isfile(filepath):
                    continue
                is_old = (now - os.path.getmtime(filepath)) / 3600 > hours_keep
                filtered = any(filt in filepath for filt in keep_patterns)
                if is_old and not filtered:
                    os.remove(filepath)
                    cleaned += 1
                    logger.debug(f"Removed file {filepath}")
            except OSError as e:
                logger.warning(f"Failed to remove file {filepath}: {e}")

        if cleaned > 0:
            logger.info(f"Cleaned up {cleaned} temporary files")
| # ============================================================================ | |
| # Initialize Application | |
| # ============================================================================ | |
# Build the shared configuration; CUDA is used whenever torch reports it.
app_config = AppConfig(
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Load the network and its DF state, then switch the network to eval mode
# on the configured device.
logger.info(f"Loading DeepFilterNet2 model on {app_config.device}")
model, df, _ = init_df(app_config.model_path, config_allow_defaults=True)
model = model.to(app_config.device)
model.eval()

# Singletons used by the processing pipeline below.
audio_processor = AudioProcessor(model, df, app_config)
audio_loader = AudioLoader()
visualizer = SpectrogramVisualizer()
file_manager = FileManager()

# Dropdown label -> noise sample path; "None" disables noise mixing.
NOISES = {
    "None": None,
    "Kitchen": "samples/dkitchen.wav",
    "Living Room": "samples/dliving.wav",
    "River": "samples/nriver.wav",
    "Cafe": "samples/scafe.wav",
}
| # ============================================================================ | |
| # Main Processing Function | |
| # ============================================================================ | |
def process_audio(
    speech_file: Optional[str],
    noise_type: str,
    snr: int,
    mic_input: Optional[str] = None,
) -> Tuple[str, PILImage.Image, str, PILImage.Image]:
    """Main audio processing pipeline.

    Loads the input (or a bundled default sample), optionally mixes in
    background noise, enhances the result, writes both signals to temporary
    WAV files and renders a spectrogram for each.

    Args:
        speech_file: Path to uploaded audio file, or ``None`` to use the
            default sample.
        noise_type: Key of ``NOISES`` naming the background noise to add.
        snr: Signal-to-noise ratio in dB. It is converted with ``int`` so
            the string values offered by the dropdown are accepted too.
        mic_input: Path to microphone recording; overrides ``speech_file``
            when set.

    Returns:
        Tuple of (noisy_audio_path, noisy_spectrogram, enhanced_audio_path,
        enhanced_spectrogram).

    Raises:
        gr.Error: If any step fails; the original exception is chained.
    """
    try:
        # Use mic input if available
        if mic_input:
            speech_file = mic_input

        sr = app_config.sample_rate
        max_len = app_config.max_duration_seconds * sr
        logger.info(f"Processing: file={speech_file}, noise={noise_type}, snr={snr}")

        # Load input audio
        if speech_file is not None:
            speech_file = audio_loader.ensure_wav(speech_file)
            sample, meta = load_audio(speech_file, sr)
            # Limit duration by keeping a random window of the maximum length
            if sample.shape[-1] > max_len:
                logger.warning(f"Audio too long, truncating to {app_config.max_duration_seconds}s")
                start = torch.randint(0, sample.shape[-1] - max_len, ()).item()
                sample = sample[..., start : start + max_len]
        else:
            # Use default sample
            sample, meta = load_audio("samples/p232_013_clean.wav", sr)
            sample = sample[..., :max_len]

        # Convert to mono if needed
        if sample.dim() > 1 and sample.shape[0] > 1:
            logger.info(f"Converting from {sample.shape[0]} channels to mono")
            sample = sample.mean(dim=0, keepdim=True)
        logger.info(f"Loaded audio with shape {sample.shape}")

        # Add noise if specified
        noise_fn = NOISES.get(noise_type)
        if noise_fn is not None:
            noise, _ = load_audio(noise_fn, sr)
            logger.info(f"Adding {noise_type} noise at {snr} dB SNR")
            _, _, sample = audio_processor.mix_at_snr(sample, noise, int(snr))

        # Enhance audio
        enhanced = audio_processor.enhance_audio(sample)
        logger.info("Audio enhancement completed")

        # Resample both signals back to the input file's own rate
        if meta.sample_rate != sr and meta.sample_rate > 0:
            enhanced = resample(enhanced, sr, meta.sample_rate)
            sample = resample(sample, sr, meta.sample_rate)
            sr = meta.sample_rate

        # Reserve the output paths; the handles are closed right away so
        # that only save_audio writes to the files.
        with tempfile.NamedTemporaryFile(suffix="_noisy.wav", delete=False) as tmp_noisy:
            noisy_wav = tmp_noisy.name
        save_audio(noisy_wav, sample, sr)
        with tempfile.NamedTemporaryFile(suffix="_enhanced.wav", delete=False) as tmp_enhanced:
            enhanced_wav = tmp_enhanced.name
        save_audio(enhanced_wav, enhanced, sr)
        logger.info(f"Saved outputs: {noisy_wav}, {enhanced_wav}")

        # Create spectrograms
        noisy_spec = visualizer.create_spectrogram(
            sample,
            visualizer.fig_noisy,
            visualizer.ax_noisy,
            sr=sr,
            title="Noisy Audio Spectrogram",
        )
        enhanced_spec = visualizer.create_spectrogram(
            enhanced,
            visualizer.fig_enh,
            visualizer.ax_enh,
            sr=sr,
            title="Enhanced Audio Spectrogram",
        )

        # Cleanup old files, protecting everything this request still needs
        filter_files = [
            path for path in (speech_file, noisy_wav, enhanced_wav, mic_input) if path
        ]
        file_manager.cleanup_tmp(filter_files, app_config.cleanup_hours, app_config.temp_dir)

        return noisy_wav, noisy_spec, enhanced_wav, enhanced_spec
    except Exception as e:
        # loguru has no ``exc_info`` keyword: it would not log a traceback
        # and would run str.format over the already interpolated message.
        # ``logger.exception`` records the traceback, and passing the error
        # as a format argument keeps braces in its text harmless.
        logger.exception("Error processing audio: {}", e)
        raise gr.Error(f"Processing failed: {str(e)}") from e
def toggle_input_mode(choice: str):
    """Show one input widget, hide the other, and clear both values.

    Args:
        choice: Selected input method; ``"mic"`` selects the recorder and
            any other value selects the file upload.

    Returns:
        Pair of Gradio updates ordered (microphone widget, file widget).
    """
    use_mic = choice == "mic"
    mic_update = gr.update(visible=use_mic, value=None)
    file_update = gr.update(visible=not use_mic, value=None)
    return mic_update, file_update
| # ============================================================================ | |
| # Gradio Interface | |
| # ============================================================================ | |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header and feature overview.
    gr.Markdown(
        """
        # 🎵 DeepFilterNet2 Audio Denoising Demo
        Remove background noise from your audio recordings using state-of-the-art deep learning.
        Upload an audio file or record directly, optionally add synthetic noise, and enhance the quality.
        **Features:**
        - Support for multiple audio formats (MP3, WAV, M4A, OGG, FLAC)
        - Real-time microphone recording
        - Customizable background noise injection
        - Visual spectrogram comparison
        - Up to 1 hour of audio processing
        """
    )

    with gr.Row():
        # Left column: input source and processing options.
        with gr.Column(scale=1):
            gr.Markdown("### Input Settings")
            input_mode = gr.Radio(
                ["file", "mic"],
                value="file",
                label="Input Method",
                info="Choose how to provide your audio",
            )
            # Only one of the two audio inputs is visible at a time.
            audio_file = gr.Audio(type="filepath", label="Upload Audio File", visible=True)
            mic_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="Record Audio",
                visible=False,
            )

            gr.Markdown("### Enhancement Settings")
            noise_type = gr.Dropdown(
                label="Background Noise Type",
                choices=list(NOISES),
                value="None",
                info="Add synthetic background noise for testing",
            )
            snr = gr.Dropdown(
                label="Signal-to-Noise Ratio (dB)",
                choices=["-5", "0", "10", "20"],
                value="10",
                info="Higher values = less noise",
            )
            process_btn = gr.Button("🚀 Denoise Audio", variant="primary", size="lg")

        # Right column: one tab per signal, each with audio and spectrogram.
        with gr.Column(scale=2):
            gr.Markdown("### Results")
            with gr.Tab("Noisy Audio"):
                noisy_audio = gr.Audio(type="filepath", label="Noisy Audio")
                noisy_spec = gr.Image(label="Noisy Spectrogram")
            with gr.Tab("Enhanced Audio"):
                enhanced_audio = gr.Audio(type="filepath", label="Enhanced Audio")
                enhanced_spec = gr.Image(label="Enhanced Spectrogram")

    # Output components shared by the examples and the button handler.
    result_outputs = (noisy_audio, noisy_spec, enhanced_audio, enhanced_spec)

    # Examples
    gr.Markdown("### 📝 Example Inputs")
    example_rows = [
        ["./samples/p232_013_clean.wav", "Kitchen", "10"],
        ["./samples/p232_013_clean.wav", "Cafe", "10"],
        ["./samples/p232_019_clean.wav", "Cafe", "10"],
        ["./samples/p232_019_clean.wav", "River", "10"],
    ]
    gr.Examples(
        examples=example_rows,
        inputs=[audio_file, noise_type, snr],
        outputs=list(result_outputs),
        fn=process_audio,
        cache_examples=True,
        label="Try these examples",
    )

    # Information
    gr.Markdown(
        """
        ### ℹ️ How It Works
        1. **Upload or Record**: Choose your input method and provide audio
        2. **Configure** (Optional): Add synthetic noise for testing the denoiser
        3. **Process**: Click "Denoise Audio" to enhance your recording
        4. **Compare**: View spectrograms and listen to before/after results
        ### 📊 Technical Details
        - **Model**: DeepFilterNet2 - Real-time speech enhancement
        - **Max Duration**: 1 hour per file
        - **Sample Rate**: 48 kHz
        - **Supported Formats**: WAV, MP3, M4A, OGG, FLAC, AAC
        ### 🎯 Best Results
        - Use clear speech recordings
        - Avoid extreme clipping or distortion
        - For best quality, use WAV format at 48kHz
        """
    )

    # Event handlers
    process_btn.click(
        fn=process_audio,
        inputs=[audio_file, noise_type, snr, mic_input],
        outputs=list(result_outputs),
        api_name="denoise",
    )
    input_mode.change(
        fn=toggle_input_mode,
        inputs=input_mode,
        outputs=[mic_input, audio_file],
    )
# Sweep stale temporary files once when the module is loaded.
file_manager.cleanup_tmp()

# Start the server only when this file is run as a script.
if __name__ == "__main__":
    queued_demo = demo.queue()
    queued_demo.launch(server_name="0.0.0.0", server_port=7860, share=False)