humairawan committed on
Commit fb24e3e · verified · 1 Parent(s): e54e4dc

Update app.py

Files changed (1): app.py +80 -473
app.py CHANGED
@@ -29,319 +29,116 @@ from df.io import resample
class AppConfig:
    """Application configuration"""
    device: torch.device
-    sample_rate: int = 48000
+    model_sample_rate: int = 48000
    max_duration_seconds: int = 3600
    cleanup_hours: int = 2
    temp_dir: str = "/tmp"
    model_path: str = "./DeepFilterNet2"
    fade_duration: float = 0.15
-
+
+# ============================================================================
+# Audio Processing Classes
+# ============================================================================

class AudioProcessor:
-    """Handles audio processing operations"""
-
    def __init__(self, model, df, config: AppConfig):
        self.model = model
        self.df = df
        self.config = config
-
+
    def mix_at_snr(self, clean: Tensor, noise: Tensor, snr: float, eps: float = 1e-10) -> Tuple[Tensor, Tensor, Tensor]:
-        """Mix clean and noise signal at a given SNR with improved error handling.
-
-        Args:
-            clean: 1D Tensor with the clean signal to mix.
-            noise: 1D Tensor of shape.
-            snr: Signal to noise ratio in dB.
-            eps: Small epsilon for numerical stability.
-
-        Returns:
-            clean: 1D Tensor with gain changed according to the snr.
-            noise: 1D Tensor with the combined noise channels.
-            mix: 1D Tensor with added clean and noise signals.
-        """
        clean = torch.as_tensor(clean).mean(0, keepdim=True)
        noise = torch.as_tensor(noise).mean(0, keepdim=True)
-
-        # Repeat noise if shorter than clean signal
        if noise.shape[1] < clean.shape[1]:
            repeats = int(math.ceil(clean.shape[1] / noise.shape[1]))
            noise = noise.repeat((1, repeats))
-
-        # Random starting point for noise
        max_start = int(noise.shape[1] - clean.shape[1])
        start = torch.randint(0, max_start, ()).item() if max_start > 0 else 0
-        noise = noise[:, start : start + clean.shape[1]]
-
-        # Calculate SNR scaling
+        noise = noise[:, start:start + clean.shape[1]]
        E_speech = torch.mean(clean.pow(2)) + eps
        E_noise = torch.mean(noise.pow(2)) + eps
        K = torch.sqrt((E_noise / E_speech) * 10 ** (snr / 10) + eps)
        noise = noise / K
        mixture = clean + noise
-
-        # Check for clipping
-        assert torch.isfinite(mixture).all(), "Non-finite values detected in mixture"
        max_m = mixture.abs().max()
        if max_m > 1:
-            logger.warning(f"Clipping detected during mixing. Reducing gain by {1/max_m:.3f}")
            clean, noise, mixture = clean / max_m, noise / max_m, mixture / max_m
-
        return clean, noise, mixture
-
+
    def enhance_audio(self, audio: Tensor) -> Tensor:
-        """Enhance audio using the DeepFilterNet model.
-
-        Args:
-            audio: Input audio tensor
-
-        Returns:
-            Enhanced audio tensor
-        """
-        logger.info(f"Enhancing audio with shape {audio.shape}")
        with torch.no_grad():
            enhanced = enhance(self.model, self.df, audio)
-
-        # Apply fade-in to avoid clicks
-        sr = self.config.sample_rate
+        sr = self.config.model_sample_rate
        fade_samples = int(sr * self.config.fade_duration)
        lim = torch.linspace(0.0, 1.0, fade_samples).unsqueeze(0)
-        lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1)
-        enhanced = enhanced * lim
-
-        return enhanced
-
+        lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - fade_samples)), dim=1)
+        return enhanced * lim

class AudioLoader:
-    """Handles audio loading from various sources"""
-
    @staticmethod
    def ensure_wav(filepath: str) -> str:
-        """Convert audio files to WAV using ffmpeg if needed.
-
-        Args:
-            filepath: Path to input audio file
-
-        Returns:
-            Path to WAV file
-        """
        if not filepath:
            return filepath
-
-        file_ext = Path(filepath).suffix.lower()
-        if file_ext in ['.mp3', '.m4a', '.ogg', '.flac', '.aac']:
-            wav_path = str(Path(filepath).with_suffix('.wav'))
-            try:
-                subprocess.run(
-                    ["ffmpeg", "-y", "-i", filepath, "-acodec", "pcm_s16le", wav_path],
-                    check=True,
-                    capture_output=True
-                )
-                logger.info(f"Converted {file_ext} to WAV: {wav_path}")
-                return wav_path
-            except subprocess.CalledProcessError as e:
-                logger.error(f"FFmpeg conversion failed: {e.stderr}")
-                raise
+        ext = Path(filepath).suffix.lower()
+        if ext in [".mp3", ".m4a", ".ogg", ".flac", ".aac"]:
+            wav_path = str(Path(filepath).with_suffix(".wav"))
+            subprocess.run(["ffmpeg", "-y", "-i", filepath, "-acodec", "pcm_s16le", wav_path],
+                           check=True, capture_output=True)
+            return wav_path
        return filepath
-
+
    @staticmethod
-    def load_audio_gradio(
-        audio_or_file: Union[None, str, Tuple[int, np.ndarray]],
-        sr: int
-    ) -> Optional[Tuple[Tensor, AudioMetaData]]:
-        """Load audio from Gradio input (file path or recorded audio).
-
-        Args:
-            audio_or_file: Either a file path string or tuple of (sample_rate, audio_array)
-            sr: Target sample rate
-
-        Returns:
-            Tuple of (audio tensor, metadata) or None
-        """
-        if audio_or_file is None:
-            return None
-
+    def load_and_resample(audio_or_file: Union[str, Tuple[int, np.ndarray]], target_sr: int) -> Tensor:
        if isinstance(audio_or_file, str):
-            if audio_or_file.lower() == "none":
-                return None
-            # Load from file
            audio_or_file = AudioLoader.ensure_wav(audio_or_file)
-            audio, meta = load_audio(audio_or_file, sr)
+            audio, meta = load_audio(audio_or_file, target_sr)
        else:
-            # Load from Gradio recording
-            meta = AudioMetaData(-1, -1, -1, -1, "")
-            assert isinstance(audio_or_file, (tuple, list))
-            meta.sample_rate, audio_np = audio_or_file
-
-            # Handle different array shapes
+            sr, audio_np = audio_or_file
            audio_np = audio_np.reshape(audio_np.shape[0], -1).T
-
-            # Convert to float32
            if audio_np.dtype == np.int16:
                audio_np = (audio_np / (1 << 15)).astype(np.float32)
            elif audio_np.dtype == np.int32:
                audio_np = (audio_np / (1 << 31)).astype(np.float32)
-
-            audio = resample(torch.from_numpy(audio_np), meta.sample_rate, sr)
-
-        return audio, meta
-
+            audio = torch.from_numpy(audio_np)
+            if sr != target_sr:
+                audio = resample(audio, target_sr, sr)
+        return audio

class SpectrogramVisualizer:
-    """Handles spectrogram visualization"""
-
-    def __init__(self, figsize: Tuple[float, float] = (15.2, 4)):
+    def __init__(self, figsize=(15,4)):
        self.figsize = figsize
        self.fig_noisy, self.ax_noisy = plt.subplots(figsize=figsize)
-        self.fig_noisy.set_tight_layout(True)
        self.fig_enh, self.ax_enh = plt.subplots(figsize=figsize)
-        self.fig_enh.set_tight_layout(True)
-
-    def specshow(
-        self,
-        spec: Union[Tensor, np.ndarray],
-        ax: Optional[plt.Axes] = None,
-        title: Optional[str] = None,
-        xlabel: Optional[str] = None,
-        ylabel: Optional[str] = None,
-        sr: int = 48000,
-        n_fft: Optional[int] = None,
-        hop: Optional[int] = None,
-        vmin: float = -100,
-        vmax: float = 0,
-        cmap: str = "inferno",
-    ):
-        """Plot a spectrogram of shape [F, T]"""
-        spec_np = spec.cpu().numpy() if isinstance(spec, torch.Tensor) else spec
-
-        if n_fft is None:
-            n_fft = spec.shape[0] * 2 if spec.shape[0] % 2 == 0 else (spec.shape[0] - 1) * 2
-        hop = hop or n_fft // 4
-
-        t = np.arange(0, spec_np.shape[-1]) * hop / sr
-        f = np.arange(0, spec_np.shape[0]) * sr // 2 / (n_fft // 2) / 1000
-
-        im = ax.pcolormesh(
-            t, f, spec_np,
-            rasterized=True,
-            shading="auto",
-            vmin=vmin,
-            vmax=vmax,
-            cmap=cmap
-        )
-
-        if title:
-            ax.set_title(title)
-        if xlabel:
-            ax.set_xlabel(xlabel)
-        if ylabel:
-            ax.set_ylabel(ylabel)
-
-        return im
-
-    def create_spectrogram(
-        self,
-        audio: Tensor,
-        figure: plt.Figure,
-        ax: plt.Axes,
-        sr: int = 48000,
-        n_fft: int = 1024,
-        hop: int = 512,
-        title: Optional[str] = None,
-    ) -> PILImage.Image:
-        """Create spectrogram image from audio tensor"""
+
+    def create_spectrogram(self, audio: Tensor, figure: plt.Figure, ax: plt.Axes,
+                           sr: int, n_fft: int = 1024, hop: int = 512, title: str = None) -> PILImage.Image:
        audio = torch.as_tensor(audio)
-
-        # Compute STFT
        w = torch.hann_window(n_fft, device=audio.device)
        spec = torch.stft(audio, n_fft, hop, window=w, return_complex=False)
        spec = spec.div_(w.pow(2).sum())
        spec = torch.view_as_complex(spec).abs().clamp_min(1e-12).log10().mul(10)
-
-        vmax = max(0.0, spec.max().item())
-
        if spec.dim() > 2:
            spec = spec.squeeze(0)
-
        ax.clear()
-        self.specshow(
-            spec,
-            ax=ax,
-            title=title,
-            xlabel="Time [s]",
-            ylabel="Frequency [kHz]",
-            sr=sr,
-            n_fft=n_fft,
-            hop=hop,
-            vmax=vmax,
-        )
-
+        t = np.arange(spec.shape[-1]) * hop / sr
+        f = np.arange(spec.shape[0]) * sr // 2 / (n_fft // 2) / 1000
+        ax.pcolormesh(t, f, spec.cpu().numpy(), shading="auto", cmap="inferno", vmin=-100, vmax=0)
+        if title:
+            ax.set_title(title)
        figure.canvas.draw()
-        return PILImage.frombytes(
-            "RGB",
-            figure.canvas.get_width_height(),
-            figure.canvas.tostring_rgb()
-        )
-
-
-class FileManager:
-    """Manages temporary file cleanup"""
-
-    @staticmethod
-    def cleanup_tmp(filter_list: List[str] = None, hours_keep: int = 2, temp_dir: str = "/tmp"):
-        """Clean up old temporary files.
-
-        Args:
-            filter_list: List of file patterns to keep
-            hours_keep: Number of hours to keep files
-            temp_dir: Temporary directory path
-        """
-        if filter_list is None:
-            filter_list = []
-        filter_list.append("p232")
-
-        if not os.path.exists(temp_dir):
-            return
-
-        logger.info(f"Cleaning up temporary files older than {hours_keep} hours")
-        cleaned = 0
-
-        for filepath in glob.glob(os.path.join(temp_dir, "*")):
-            try:
-                is_old = (time.time() - os.path.getmtime(filepath)) / 3600 > hours_keep
-                filtered = any(filt in filepath for filt in filter_list if filt is not None)
-
-                if is_old and not filtered:
-                    os.remove(filepath)
-                    cleaned += 1
-                    logger.debug(f"Removed file {filepath}")
-            except Exception as e:
-                logger.warning(f"Failed to remove file {filepath}: {e}")
-
-        if cleaned > 0:
-            logger.info(f"Cleaned up {cleaned} temporary files")
-
+        return PILImage.frombytes("RGB", figure.canvas.get_width_height(), figure.canvas.tostring_rgb())

# ============================================================================
-# Initialize Application
+# Initialization
# ============================================================================

-# Setup configuration
-app_config = AppConfig(
-    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
-)
-
-# Initialize model
-logger.info(f"Loading DeepFilterNet2 model on {app_config.device}")
+app_config = AppConfig(device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
model, df, _ = init_df(app_config.model_path, config_allow_defaults=True)
model = model.to(device=app_config.device).eval()
-
-# Initialize components
audio_processor = AudioProcessor(model, df, app_config)
audio_loader = AudioLoader()
visualizer = SpectrogramVisualizer()
-file_manager = FileManager()

-# Noise options
NOISES = {
    "None": None,
    "Kitchen": "samples/dkitchen.wav",
@@ -350,7 +147,6 @@ NOISES = {
    "Cafe": "samples/scafe.wav",
}

-
# ============================================================================
# Main Processing Function
# ============================================================================
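A note on the SNR arithmetic this commit keeps in place: `mix_at_snr` rescales the noise by a gain `K` derived from the energy ratio of the two signals, so the mixture lands at the requested SNR. A minimal self-contained sketch of that formula (the random tensors and the printout are illustrative; only the `K` computation comes from the method above):

import torch

# Illustrative inputs; mix_at_snr receives real speech/noise shaped (1, N).
clean = torch.randn(1, 48000)
noise = torch.randn(1, 48000)
snr_db, eps = 10.0, 1e-10

E_speech = clean.pow(2).mean() + eps
E_noise = noise.pow(2).mean() + eps
# Dividing noise by K leaves it with energy E_speech * 10 ** (-snr_db / 10).
K = torch.sqrt((E_noise / E_speech) * 10 ** (snr_db / 10) + eps)
scaled_noise = noise / K
measured = 10 * torch.log10(clean.pow(2).mean() / scaled_noise.pow(2).mean())
print(measured)  # ~= snr_db, up to the eps terms

So `clean + noise / K` has a speech-to-noise energy ratio of `snr_db` dB, which is why process_audio can pass the slider value straight through.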
@@ -359,250 +155,61 @@ def process_audio(
    speech_file: Optional[str],
    noise_type: str,
    snr: int,
+    target_rate: int = 22050,
    mic_input: Optional[str] = None,
) -> Tuple[str, PILImage.Image, str, PILImage.Image]:
-    """Main audio processing pipeline.
-
-    Args:
-        speech_file: Path to uploaded audio file
-        noise_type: Type of background noise to add
-        snr: Signal-to-noise ratio in dB
-        mic_input: Path to microphone recording
-
-    Returns:
-        Tuple of (noisy_audio_path, noisy_spectrogram, enhanced_audio_path, enhanced_spectrogram)
-    """
-    try:
-        # Use mic input if available
-        if mic_input:
-            speech_file = mic_input
-
-        sr = app_config.sample_rate
-        logger.info(f"Processing: file={speech_file}, noise={noise_type}, snr={snr}")
-
-        # Load input audio
-        if speech_file is not None:
-            speech_file = audio_loader.ensure_wav(speech_file)
-            sample, meta = load_audio(speech_file, sr)
-
-            # Limit duration
-            max_len = app_config.max_duration_seconds * sr
-            if sample.shape[-1] > max_len:
-                logger.warning(f"Audio too long, truncating to {app_config.max_duration_seconds}s")
-                start = torch.randint(0, sample.shape[-1] - max_len, ()).item()
-                sample = sample[..., start : start + max_len]
-        else:
-            # Use default sample
-            sample, meta = load_audio("samples/p232_013_clean.wav", sr)
-            sample = sample[..., : app_config.max_duration_seconds * sr]
-
-        # Convert to mono if needed
-        if sample.dim() > 1 and sample.shape[0] > 1:
-            logger.info(f"Converting from {sample.shape[0]} channels to mono")
-            sample = sample.mean(dim=0, keepdim=True)
-
-        logger.info(f"Loaded audio with shape {sample.shape}")
-
-        # Add noise if specified
-        noise_fn = NOISES.get(noise_type)
-        if noise_fn is not None:
-            noise, _ = load_audio(noise_fn, sr)
-            logger.info(f"Adding {noise_type} noise at {snr} dB SNR")
-            _, _, sample = audio_processor.mix_at_snr(sample, noise, int(snr))
-
-        # Enhance audio
-        enhanced = audio_processor.enhance_audio(sample)
-        logger.info("Audio enhancement completed")
-
-        # Resample if needed
-        if meta.sample_rate != sr and meta.sample_rate > 0:
-            enhanced = resample(enhanced, sr, meta.sample_rate)
-            sample = resample(sample, sr, meta.sample_rate)
-            sr = meta.sample_rate
-
-        # Save audio files
-        noisy_wav = tempfile.NamedTemporaryFile(suffix="_noisy.wav", delete=False).name
-        save_audio(noisy_wav, sample, sr)
-
-        enhanced_wav = tempfile.NamedTemporaryFile(suffix="_enhanced.wav", delete=False).name
-        save_audio(enhanced_wav, enhanced, sr)
-
-        logger.info(f"Saved outputs: {noisy_wav}, {enhanced_wav}")
-
-        # Create spectrograms
-        noisy_spec = visualizer.create_spectrogram(
-            sample,
-            visualizer.fig_noisy,
-            visualizer.ax_noisy,
-            sr=sr,
-            title="Noisy Audio Spectrogram"
-        )
-
-        enhanced_spec = visualizer.create_spectrogram(
-            enhanced,
-            visualizer.fig_enh,
-            visualizer.ax_enh,
-            sr=sr,
-            title="Enhanced Audio Spectrogram"
-        )
-
-        # Cleanup old files
-        filter_files = [speech_file, noisy_wav, enhanced_wav]
-        if mic_input:
-            filter_files.append(mic_input)
-        file_manager.cleanup_tmp(filter_files, app_config.cleanup_hours)
-
-        return noisy_wav, noisy_spec, enhanced_wav, enhanced_spec
-
-    except Exception as e:
-        logger.error(f"Error processing audio: {e}", exc_info=True)
-        raise gr.Error(f"Processing failed: {str(e)}")
+    if mic_input:
+        speech_file = mic_input
+    model_sr = app_config.model_sample_rate
+    audio = audio_loader.load_and_resample(speech_file, model_sr)
+
+    # Add noise if requested
+    noise_fn = NOISES.get(noise_type)
+    if noise_fn:
+        noise_audio = audio_loader.load_and_resample(noise_fn, model_sr)
+        _, _, audio = audio_processor.mix_at_snr(audio, noise_audio, snr)
+
+    enhanced = audio_processor.enhance_audio(audio)
+
+    # Downsample back to target rate if needed
+    if target_rate != model_sr:
+        enhanced = resample(enhanced, target_rate, model_sr)
+        audio = resample(audio, target_rate, model_sr)
+
+    noisy_wav = tempfile.NamedTemporaryFile(suffix="_noisy.wav", delete=False).name
+    enhanced_wav = tempfile.NamedTemporaryFile(suffix="_enhanced.wav", delete=False).name
+    save_audio(noisy_wav, audio, target_rate)
+    save_audio(enhanced_wav, enhanced, target_rate)
+
+    noisy_spec = visualizer.create_spectrogram(audio, visualizer.fig_noisy, visualizer.ax_noisy,
+                                               sr=target_rate, title="Noisy Audio")
+    enhanced_spec = visualizer.create_spectrogram(enhanced, visualizer.fig_enh, visualizer.ax_enh,
+                                                  sr=target_rate, title="Enhanced Audio")
+    return noisy_wav, noisy_spec, enhanced_wav, enhanced_spec

-def toggle_input_mode(choice: str):
-    """Toggle between microphone and file upload."""
-    if choice == "mic":
-        return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
-    else:
-        return gr.update(visible=False, value=None), gr.update(visible=True, value=None)

# ============================================================================
# Gradio Interface
# ============================================================================

-with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown(
-        """
-        # 🎵 DeepFilterNet2 Audio Denoising Demo
-
-        Remove background noise from your audio recordings using state-of-the-art deep learning.
-        Upload an audio file or record directly, optionally add synthetic noise, and enhance the quality.
-
-        **Features:**
-        - Support for multiple audio formats (MP3, WAV, M4A, OGG, FLAC)
-        - Real-time microphone recording
-        - Customizable background noise injection
-        - Visual spectrogram comparison
-        - Up to 1 hour of audio processing
-        """
-    )
-
-    with gr.Row():
-        with gr.Column(scale=1):
-            gr.Markdown("### Input Settings")
-
-            input_mode = gr.Radio(
-                ["file", "mic"],
-                value="file",
-                label="Input Method",
-                info="Choose how to provide your audio"
-            )
-
-            audio_file = gr.Audio(
-                type="filepath",
-                label="Upload Audio File",
-                visible=True
-            )
-
-            mic_input = gr.Audio(
-                sources=["microphone"],
-                type="filepath",
-                label="Record Audio",
-                visible=False
-            )
-
-            gr.Markdown("### Enhancement Settings")
-
-            noise_type = gr.Dropdown(
-                label="Background Noise Type",
-                choices=list(NOISES.keys()),
-                value="None",
-                info="Add synthetic background noise for testing"
-            )
-
-            snr = gr.Dropdown(
-                label="Signal-to-Noise Ratio (dB)",
-                choices=["-5", "0", "10", "20"],
-                value="10",
-                info="Higher values = less noise"
-            )
-
-            process_btn = gr.Button("🚀 Denoise Audio", variant="primary", size="lg")
-
-        with gr.Column(scale=2):
-            gr.Markdown("### Results")
-
-            with gr.Tab("Noisy Audio"):
-                noisy_audio = gr.Audio(type="filepath", label="Noisy Audio")
-                noisy_spec = gr.Image(label="Noisy Spectrogram")
-
-            with gr.Tab("Enhanced Audio"):
-                enhanced_audio = gr.Audio(type="filepath", label="Enhanced Audio")
-                enhanced_spec = gr.Image(label="Enhanced Spectrogram")
-
-    # Examples
-    gr.Markdown("### 📝 Example Inputs")
-    gr.Examples(
-        examples=[
-            ["./samples/p232_013_clean.wav", "Kitchen", "10"],
-            ["./samples/p232_013_clean.wav", "Cafe", "10"],
-            ["./samples/p232_019_clean.wav", "Cafe", "10"],
-            ["./samples/p232_019_clean.wav", "River", "10"],
-        ],
-        inputs=[audio_file, noise_type, snr],
-        outputs=[noisy_audio, noisy_spec, enhanced_audio, enhanced_spec],
-        fn=process_audio,
-        cache_examples=True,
-        label="Try these examples",
-    )
-
-    # Information
-    gr.Markdown(
-        """
-        ### ℹ️ How It Works
-
-        1. **Upload or Record**: Choose your input method and provide audio
-        2. **Configure** (Optional): Add synthetic noise for testing the denoiser
-        3. **Process**: Click "Denoise Audio" to enhance your recording
-        4. **Compare**: View spectrograms and listen to before/after results
-
-        ### 📊 Technical Details
-
-        - **Model**: DeepFilterNet2 - Real-time speech enhancement
-        - **Max Duration**: 1 hour per file
-        - **Sample Rate**: 48 kHz
-        - **Supported Formats**: WAV, MP3, M4A, OGG, FLAC, AAC
-
-        ### 🎯 Best Results
-
-        - Use clear speech recordings
-        - Avoid extreme clipping or distortion
-        - For best quality, use WAV format at 48kHz
-        """
-    )
-
-    # Event handlers
+with gr.Blocks() as demo:
+    gr.Markdown("# 🎵 DeepFilterNet2 Denoiser with Resampling Support")
+    audio_file = gr.Audio(type="filepath", label="Upload Audio")
+    mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Record Audio")
+    noise_type = gr.Dropdown(label="Noise Type", choices=list(NOISES.keys()), value="None")
+    snr = gr.Slider(label="SNR (dB)", minimum=-10, maximum=30, step=1, value=10)
+    target_rate = gr.Dropdown(label="Output Sample Rate", choices=[16000, 22050, 44100, 48000], value=22050)
+    process_btn = gr.Button("🚀 Enhance Audio")
+    noisy_audio = gr.Audio(type="filepath")
+    noisy_spec = gr.Image()
+    enhanced_audio = gr.Audio(type="filepath")
+    enhanced_spec = gr.Image()
+
    process_btn.click(
        fn=process_audio,
-        inputs=[audio_file, noise_type, snr, mic_input],
-        outputs=[noisy_audio, noisy_spec, enhanced_audio, enhanced_spec],
-        api_name="denoise",
-    )
-
-    input_mode.change(
-        fn=toggle_input_mode,
-        inputs=input_mode,
-        outputs=[mic_input, audio_file],
+        inputs=[audio_file, noise_type, snr, target_rate, mic_input],
+        outputs=[noisy_audio, noisy_spec, enhanced_audio, enhanced_spec]
    )

-# Initial cleanup
-file_manager.cleanup_tmp()
-
-# Launch application
if __name__ == "__main__":
-    demo.queue().launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        share=False,
-    )
+    demo.launch()
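One detail worth verifying when reading `enhance_audio` after this change: the fade-in mask is a `fade_samples`-long linear ramp concatenated with ones, and the new `enhanced.shape[1] - fade_samples` term keeps that mask exactly as long as the signal (the old `enhanced.shape[1] - lim.shape[1]` computed the same quantity, since `lim` held `fade_samples` columns at that point). A minimal sketch, assuming a mono `(1, num_samples)` tensor and the 48 kHz / 0.15 s values from `AppConfig`; the random input is a placeholder for the model output:

import torch

sr, fade_duration = 48000, 0.15          # AppConfig.model_sample_rate, AppConfig.fade_duration
enhanced = torch.randn(1, 5 * sr)        # placeholder for enhance(model, df, audio)

fade_samples = int(sr * fade_duration)   # 7200 samples
ramp = torch.linspace(0.0, 1.0, fade_samples).unsqueeze(0)
mask = torch.cat((ramp, torch.ones(1, enhanced.shape[1] - fade_samples)), dim=1)
assert mask.shape == enhanced.shape      # envelope spans the full signal
faded = enhanced * mask                  # output ramps up from silence over 0.15 s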