humair025 committed on
Commit
4b1cd22
·
verified ·
1 Parent(s): 73abbbc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +554 -311
app.py CHANGED
@@ -3,10 +3,10 @@ import math
3
  import os
4
  import tempfile
5
  import time
 
6
  from typing import List, Optional, Tuple, Union
7
-
8
  import subprocess
9
- # import os
10
 
11
  import gradio as gr
12
  import matplotlib.pyplot as plt
@@ -21,19 +21,327 @@ from df import config
21
  from df.enhance import enhance, init_df, load_audio, save_audio
22
  from df.io import resample
23
 
24
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
- model, df, _ = init_df("./DeepFilterNet2", config_allow_defaults=True)
26
- model = model.to(device=device).eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- fig_noisy: plt.Figure
29
- fig_enh: plt.Figure
30
- ax_noisy: plt.Axes
31
- ax_enh: plt.Axes
32
- fig_noisy, ax_noisy = plt.subplots(figsize=(15.2, 4))
33
- fig_noisy.set_tight_layout(True)
34
- fig_enh, ax_enh = plt.subplots(figsize=(15.2, 4))
35
- fig_enh.set_tight_layout(True)
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  NOISES = {
38
  "None": None,
39
  "Kitchen": "samples/dkitchen.wav",
@@ -43,323 +351,258 @@ NOISES = {
43
  }
44
 
45
 
46
- def mix_at_snr(clean, noise, snr, eps=1e-10):
47
- """Mix clean and noise signal at a given SNR.
 
48
 
 
 
 
 
 
 
 
 
49
  Args:
50
- clean: 1D Tensor with the clean signal to mix.
51
- noise: 1D Tensor of shape.
52
- snr: Signal to noise ratio.
53
-
 
54
  Returns:
55
- clean: 1D Tensor with gain changed according to the snr.
56
- noise: 1D Tensor with the combined noise channels.
57
- mix: 1D Tensor with added clean and noise signals.
58
-
59
  """
60
- clean = torch.as_tensor(clean).mean(0, keepdim=True)
61
- noise = torch.as_tensor(noise).mean(0, keepdim=True)
62
- if noise.shape[1] < clean.shape[1]:
63
- noise = noise.repeat((1, int(math.ceil(clean.shape[1] / noise.shape[1]))))
64
- max_start = int(noise.shape[1] - clean.shape[1])
65
- start = torch.randint(0, max_start, ()).item() if max_start > 0 else 0
66
- logger.debug(f"start: {start}, {clean.shape}")
67
- noise = noise[:, start : start + clean.shape[1]]
68
- E_speech = torch.mean(clean.pow(2)) + eps
69
- E_noise = torch.mean(noise.pow(2))
70
- K = torch.sqrt((E_noise / E_speech) * 10 ** (snr / 10) + eps)
71
- noise = noise / K
72
- mixture = clean + noise
73
- logger.debug("mixture: {mixture.shape}")
74
- assert torch.isfinite(mixture).all()
75
- max_m = mixture.abs().max()
76
- if max_m > 1:
77
- logger.warning(f"Clipping detected during mixing. Reducing gain by {1/max_m}")
78
- clean, noise, mixture = clean / max_m, noise / max_m, mixture / max_m
79
- return clean, noise, mixture
80
-
81
-
82
- def load_audio_gradio(
83
- audio_or_file: Union[None, str, Tuple[int, np.ndarray]], sr: int
84
- ) -> Optional[Tuple[Tensor, AudioMetaData]]:
85
- if audio_or_file is None:
86
- return None
87
- if isinstance(audio_or_file, str):
88
- if audio_or_file.lower() == "none":
89
- return None
90
- # First try default format
91
- audio, meta = load_audio(audio_or_file, sr)
92
- else:
93
- meta = AudioMetaData(-1, -1, -1, -1, "")
94
- assert isinstance(audio_or_file, (tuple, list))
95
- meta.sample_rate, audio_np = audio_or_file
96
- # Gradio documentation says, the shape is [samples, 2], but apparently sometimes its not.
97
- audio_np = audio_np.reshape(audio_np.shape[0], -1).T
98
- if audio_np.dtype == np.int16:
99
- audio_np = (audio_np / (1 << 15)).astype(np.float32)
100
- elif audio_np.dtype == np.int32:
101
- audio_np = (audio_np / (1 << 31)).astype(np.float32)
102
- audio = resample(torch.from_numpy(audio_np), meta.sample_rate, sr)
103
- return audio, meta
104
-
105
-
106
- def ensure_wav(filepath: str) -> str:
107
- """Convert MP3 (or other formats) to WAV using ffmpeg if needed."""
108
- if filepath.lower().endswith(".mp3"):
109
- wav_path = filepath.rsplit(".", 1)[0] + ".wav"
110
- subprocess.run(["ffmpeg", "-y", "-i", filepath, wav_path], check=True)
111
- return wav_path
112
- return filepath
113
-
114
-
115
- def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str] = None):
116
- if mic_input:
117
- speech_upl = mic_input
118
-
119
- sr = config("sr", 48000, int, section="df")
120
- logger.info(f"Got parameters speech_upl: {speech_upl}, noise: {noise_type}, snr: {snr}")
121
- snr = int(snr)
122
- noise_fn = NOISES[noise_type]
123
- meta = AudioMetaData(-1, -1, -1, -1, "")
124
-
125
- max_s = 3600 # allow up to 1 hour (3600 seconds)
126
-
127
- if speech_upl is not None:
128
- # ✅ Ensure compatible WAV input
129
- speech_upl = ensure_wav(speech_upl)
130
-
131
- sample, meta = load_audio(speech_upl, sr)
132
- max_len = max_s * sr
133
- if sample.shape[-1] > max_len:
134
- start = torch.randint(0, sample.shape[-1] - max_len, ()).item()
135
- sample = sample[..., start : start + max_len]
136
- else:
137
- sample, meta = load_audio("samples/p232_013_clean.wav", sr)
138
- sample = sample[..., : max_s * sr]
139
-
140
- if sample.dim() > 1 and sample.shape[0] > 1:
141
- assert sample.shape[1] > sample.shape[0], f"Expecting channels first, but got {sample.shape}"
142
- sample = sample.mean(dim=0, keepdim=True)
143
-
144
- logger.info(f"Loaded sample with shape {sample.shape}")
145
-
146
- if noise_fn is not None:
147
- noise, _ = load_audio(noise_fn, sr) # type: ignore
148
- logger.info(f"Loaded noise with shape {noise.shape}")
149
- _, _, sample = mix_at_snr(sample, noise, snr)
150
-
151
- logger.info("Start denoising audio")
152
- enhanced = enhance(model, df, sample)
153
- logger.info("Denoising finished")
154
-
155
- lim = torch.linspace(0.0, 1.0, int(sr * 0.15)).unsqueeze(0)
156
- lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1)
157
- enhanced = enhanced * lim
158
-
159
- if meta.sample_rate != sr:
160
- enhanced = resample(enhanced, sr, meta.sample_rate)
161
- sample = resample(sample, sr, meta.sample_rate)
162
- sr = meta.sample_rate
163
-
164
- noisy_wav = tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False).name
165
- save_audio(noisy_wav, sample, sr)
166
- enhanced_wav = tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False).name
167
- save_audio(enhanced_wav, enhanced, sr)
168
-
169
- logger.info(f"saved audios: {noisy_wav}, {enhanced_wav}")
170
-
171
- ax_noisy.clear()
172
- ax_enh.clear()
173
- noisy_im = spec_im(sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
174
- enh_im = spec_im(enhanced, sr=sr, figure=fig_enh, ax=ax_enh)
175
-
176
- filter = [speech_upl, noisy_wav, enhanced_wav]
177
- if mic_input is not None and mic_input != "":
178
- filter.append(mic_input)
179
- cleanup_tmp(filter)
180
-
181
- return noisy_wav, noisy_im, enhanced_wav, enh_im
182
-
183
- def specshow(
184
- spec,
185
- ax=None,
186
- title=None,
187
- xlabel=None,
188
- ylabel=None,
189
- sr=48000,
190
- n_fft=None,
191
- hop=None,
192
- t=None,
193
- f=None,
194
- vmin=-100,
195
- vmax=0,
196
- xlim=None,
197
- ylim=None,
198
- cmap="inferno",
199
- ):
200
- """Plots a spectrogram of shape [F, T]"""
201
- spec_np = spec.cpu().numpy() if isinstance(spec, torch.Tensor) else spec
202
- if ax is not None:
203
- set_title = ax.set_title
204
- set_xlabel = ax.set_xlabel
205
- set_ylabel = ax.set_ylabel
206
- set_xlim = ax.set_xlim
207
- set_ylim = ax.set_ylim
208
- else:
209
- ax = plt
210
- set_title = plt.title
211
- set_xlabel = plt.xlabel
212
- set_ylabel = plt.ylabel
213
- set_xlim = plt.xlim
214
- set_ylim = plt.ylim
215
- if n_fft is None:
216
- if spec.shape[0] % 2 == 0:
217
- n_fft = spec.shape[0] * 2
218
  else:
219
- n_fft = (spec.shape[0] - 1) * 2
220
- hop = hop or n_fft // 4
221
- if t is None:
222
- t = np.arange(0, spec_np.shape[-1]) * hop / sr
223
- if f is None:
224
- f = np.arange(0, spec_np.shape[0]) * sr // 2 / (n_fft // 2) / 1000
225
- im = ax.pcolormesh(
226
- t, f, spec_np, rasterized=True, shading="auto", vmin=vmin, vmax=vmax, cmap=cmap
227
- )
228
- if title is not None:
229
- set_title(title)
230
- if xlabel is not None:
231
- set_xlabel(xlabel)
232
- if ylabel is not None:
233
- set_ylabel(ylabel)
234
- if xlim is not None:
235
- set_xlim(xlim)
236
- if ylim is not None:
237
- set_ylim(ylim)
238
- return im
239
-
240
-
241
- def spec_im(
242
- audio: torch.Tensor,
243
- figsize=(15, 5),
244
- colorbar=False,
245
- colorbar_format=None,
246
- figure=None,
247
- labels=True,
248
- **kwargs,
249
- ) -> Image:
250
- audio = torch.as_tensor(audio)
251
- if labels:
252
- kwargs.setdefault("xlabel", "Time [s]")
253
- kwargs.setdefault("ylabel", "Frequency [Hz]")
254
- n_fft = kwargs.setdefault("n_fft", 1024)
255
- hop = kwargs.setdefault("hop", 512)
256
- w = torch.hann_window(n_fft, device=audio.device)
257
- spec = torch.stft(audio, n_fft, hop, window=w, return_complex=False)
258
- spec = spec.div_(w.pow(2).sum())
259
- spec = torch.view_as_complex(spec).abs().clamp_min(1e-12).log10().mul(10)
260
- kwargs.setdefault("vmax", max(0.0, spec.max().item()))
261
-
262
- if figure is None:
263
- figure = plt.figure(figsize=figsize)
264
- figure.set_tight_layout(True)
265
- if spec.dim() > 2:
266
- spec = spec.squeeze(0)
267
- im = specshow(spec, **kwargs)
268
- if colorbar:
269
- ckwargs = {}
270
- if "ax" in kwargs:
271
- if colorbar_format is None:
272
- if kwargs.get("vmin", None) is not None or kwargs.get("vmax", None) is not None:
273
- colorbar_format = "%+2.0f dB"
274
- ckwargs = {"ax": kwargs["ax"]}
275
- plt.colorbar(im, format=colorbar_format, **ckwargs)
276
- figure.canvas.draw()
277
- return Image.frombytes("RGB", figure.canvas.get_width_height(), figure.canvas.tostring_rgb())
278
-
279
-
280
- def cleanup_tmp(filter: List[str] = [], hours_keep=2):
281
- filter.append("p232")
282
- logger.info(f"Filter: {filter}")
283
- # Cleanup some old wav files
284
- if os.path.exists("/tmp"):
285
- for f in glob.glob("/tmp/*"):
286
- print(f"Got file {f}")
287
- is_old = (time.time() - os.path.getmtime(f)) / 3600 > hours_keep
288
- filtered = any(filt in f for filt in filter if filt is not None)
289
- if is_old and not filtered:
290
- try:
291
- os.remove(f)
292
- logger.info(f"Removed file {f}")
293
- except Exception as e:
294
- logger.warning(f"failed to remove file {f}: {e}")
295
-
296
-
297
- def toggle(choice):
298
  if choice == "mic":
299
  return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
300
  else:
301
  return gr.update(visible=False, value=None), gr.update(visible=True, value=None)
302
 
303
 
304
- with gr.Blocks() as demo:
305
- with gr.Row():
306
- gr.Markdown(
307
- """
308
- ## DeepFilterNet2 Demo\
309
-
310
- This demo denoises audio files using DeepFilterNet. Try it with your own voice!
311
- """
312
- )
 
 
 
 
 
 
 
 
 
 
 
 
313
  with gr.Row():
314
- with gr.Column():
315
- radio = gr.Radio(
316
- ["mic", "file"], value="file", label="How would you like to upload your audio?"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  )
318
- mic_input = gr.Mic(label="Input", type="filepath", visible=False)
319
- audio_file = gr.Audio(type="filepath", label="Input", visible=True)
320
- inputs = [
321
- audio_file,
322
- gr.Dropdown(
323
- label="Add background noise",
324
- choices=list(NOISES.keys()),
325
- value="None",
326
- ),
327
- gr.Dropdown(
328
- label="Noise Level (SNR)",
329
- choices=["-5", "0", "10", "20"],
330
- value="10",
331
- ),
332
- mic_input,
333
- ]
334
- btn = gr.Button("Generate")
335
- with gr.Column():
336
- outputs = [
337
- # gr.Video(type="filepath", label="Noisy audio"),
338
- gr.Audio(type="filepath", label="Noisy audio"),
339
- gr.Image(label="Noisy spectrogram"),
340
- # gr.Video(type="filepath", label="Enhanced audio"),
341
- gr.Audio(type="filepath", label="Enhanced audio"),
342
- gr.Image(label="Enhanced spectrogram"),
343
- ]
344
- btn.click(fn=demo_fn, inputs=inputs, outputs=outputs, api_name='denoise')
345
- radio.change(toggle, radio, [mic_input, audio_file])
 
 
 
 
346
  gr.Examples(
347
- [
348
  ["./samples/p232_013_clean.wav", "Kitchen", "10"],
349
  ["./samples/p232_013_clean.wav", "Cafe", "10"],
350
  ["./samples/p232_019_clean.wav", "Cafe", "10"],
351
  ["./samples/p232_019_clean.wav", "River", "10"],
352
  ],
353
- fn=demo_fn,
354
- inputs=inputs,
355
- outputs=outputs,
356
  cache_examples=True,
357
- ),
358
- gr.Markdown(open("usage.md").read())
359
-
360
- cleanup_tmp()
361
- # demo.launch(enable_queue=True)
362
- # demo.launch()
363
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
 
 
365
 
 
 
 
 
 
 
 
 
3
  import os
4
  import tempfile
5
  import time
6
+ from pathlib import Path
7
  from typing import List, Optional, Tuple, Union
 
8
  import subprocess
9
+ from dataclasses import dataclass
10
 
11
  import gradio as gr
12
  import matplotlib.pyplot as plt
 
21
  from df.enhance import enhance, init_df, load_audio, save_audio
22
  from df.io import resample
23
 
24
+ # ============================================================================
25
+ # Configuration and Setup
26
+ # ============================================================================
27
+
28
+ @dataclass
29
+ class AppConfig:
30
+ """Application configuration"""
31
+ device: torch.device
32
+ sample_rate: int = 48000
33
+ max_duration_seconds: int = 3600
34
+ cleanup_hours: int = 2
35
+ temp_dir: str = "/tmp"
36
+ model_path: str = "./DeepFilterNet2"
37
+ fade_duration: float = 0.15
38
+
39
+
40
+ class AudioProcessor:
41
+ """Handles audio processing operations"""
42
+
43
+ def __init__(self, model, df, config: AppConfig):
44
+ self.model = model
45
+ self.df = df
46
+ self.config = config
47
+
48
+ def mix_at_snr(self, clean: Tensor, noise: Tensor, snr: float, eps: float = 1e-10) -> Tuple[Tensor, Tensor, Tensor]:
49
+ """Mix clean and noise signal at a given SNR with improved error handling.
50
+
51
+ Args:
52
+ clean: 1D Tensor with the clean signal to mix.
53
+ noise: 1D Tensor of shape.
54
+ snr: Signal to noise ratio in dB.
55
+ eps: Small epsilon for numerical stability.
56
+
57
+ Returns:
58
+ clean: 1D Tensor with gain changed according to the snr.
59
+ noise: 1D Tensor with the combined noise channels.
60
+ mix: 1D Tensor with added clean and noise signals.
61
+ """
62
+ clean = torch.as_tensor(clean).mean(0, keepdim=True)
63
+ noise = torch.as_tensor(noise).mean(0, keepdim=True)
64
+
65
+ # Repeat noise if shorter than clean signal
66
+ if noise.shape[1] < clean.shape[1]:
67
+ repeats = int(math.ceil(clean.shape[1] / noise.shape[1]))
68
+ noise = noise.repeat((1, repeats))
69
+
70
+ # Random starting point for noise
71
+ max_start = int(noise.shape[1] - clean.shape[1])
72
+ start = torch.randint(0, max_start, ()).item() if max_start > 0 else 0
73
+ noise = noise[:, start : start + clean.shape[1]]
74
+
75
+ # Calculate SNR scaling
76
+ E_speech = torch.mean(clean.pow(2)) + eps
77
+ E_noise = torch.mean(noise.pow(2)) + eps
78
+ K = torch.sqrt((E_noise / E_speech) * 10 ** (snr / 10) + eps)
79
+ noise = noise / K
80
+ mixture = clean + noise
81
+
82
+ # Check for clipping
83
+ assert torch.isfinite(mixture).all(), "Non-finite values detected in mixture"
84
+ max_m = mixture.abs().max()
85
+ if max_m > 1:
86
+ logger.warning(f"Clipping detected during mixing. Reducing gain by {1/max_m:.3f}")
87
+ clean, noise, mixture = clean / max_m, noise / max_m, mixture / max_m
88
+
89
+ return clean, noise, mixture
90
+
91
+ def enhance_audio(self, audio: Tensor) -> Tensor:
92
+ """Enhance audio using the DeepFilterNet model.
93
+
94
+ Args:
95
+ audio: Input audio tensor
96
+
97
+ Returns:
98
+ Enhanced audio tensor
99
+ """
100
+ logger.info(f"Enhancing audio with shape {audio.shape}")
101
+ with torch.no_grad():
102
+ enhanced = enhance(self.model, self.df, audio)
103
+
104
+ # Apply fade-in to avoid clicks
105
+ sr = self.config.sample_rate
106
+ fade_samples = int(sr * self.config.fade_duration)
107
+ lim = torch.linspace(0.0, 1.0, fade_samples).unsqueeze(0)
108
+ lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1)
109
+ enhanced = enhanced * lim
110
+
111
+ return enhanced
112
+
113
+
114
+ class AudioLoader:
115
+ """Handles audio loading from various sources"""
116
+
117
+ @staticmethod
118
+ def ensure_wav(filepath: str) -> str:
119
+ """Convert audio files to WAV using ffmpeg if needed.
120
+
121
+ Args:
122
+ filepath: Path to input audio file
123
+
124
+ Returns:
125
+ Path to WAV file
126
+ """
127
+ if not filepath:
128
+ return filepath
129
+
130
+ file_ext = Path(filepath).suffix.lower()
131
+ if file_ext in ['.mp3', '.m4a', '.ogg', '.flac', '.aac']:
132
+ wav_path = str(Path(filepath).with_suffix('.wav'))
133
+ try:
134
+ subprocess.run(
135
+ ["ffmpeg", "-y", "-i", filepath, "-acodec", "pcm_s16le", wav_path],
136
+ check=True,
137
+ capture_output=True
138
+ )
139
+ logger.info(f"Converted {file_ext} to WAV: {wav_path}")
140
+ return wav_path
141
+ except subprocess.CalledProcessError as e:
142
+ logger.error(f"FFmpeg conversion failed: {e.stderr}")
143
+ raise
144
+ return filepath
145
+
146
+ @staticmethod
147
+ def load_audio_gradio(
148
+ audio_or_file: Union[None, str, Tuple[int, np.ndarray]],
149
+ sr: int
150
+ ) -> Optional[Tuple[Tensor, AudioMetaData]]:
151
+ """Load audio from Gradio input (file path or recorded audio).
152
+
153
+ Args:
154
+ audio_or_file: Either a file path string or tuple of (sample_rate, audio_array)
155
+ sr: Target sample rate
156
+
157
+ Returns:
158
+ Tuple of (audio tensor, metadata) or None
159
+ """
160
+ if audio_or_file is None:
161
+ return None
162
+
163
+ if isinstance(audio_or_file, str):
164
+ if audio_or_file.lower() == "none":
165
+ return None
166
+ # Load from file
167
+ audio_or_file = AudioLoader.ensure_wav(audio_or_file)
168
+ audio, meta = load_audio(audio_or_file, sr)
169
+ else:
170
+ # Load from Gradio recording
171
+ meta = AudioMetaData(-1, -1, -1, -1, "")
172
+ assert isinstance(audio_or_file, (tuple, list))
173
+ meta.sample_rate, audio_np = audio_or_file
174
+
175
+ # Handle different array shapes
176
+ audio_np = audio_np.reshape(audio_np.shape[0], -1).T
177
+
178
+ # Convert to float32
179
+ if audio_np.dtype == np.int16:
180
+ audio_np = (audio_np / (1 << 15)).astype(np.float32)
181
+ elif audio_np.dtype == np.int32:
182
+ audio_np = (audio_np / (1 << 31)).astype(np.float32)
183
+
184
+ audio = resample(torch.from_numpy(audio_np), meta.sample_rate, sr)
185
+
186
+ return audio, meta
187
+
188
+
189
+ class SpectrogramVisualizer:
190
+ """Handles spectrogram visualization"""
191
+
192
+ def __init__(self, figsize: Tuple[float, float] = (15.2, 4)):
193
+ self.figsize = figsize
194
+ self.fig_noisy, self.ax_noisy = plt.subplots(figsize=figsize)
195
+ self.fig_noisy.set_tight_layout(True)
196
+ self.fig_enh, self.ax_enh = plt.subplots(figsize=figsize)
197
+ self.fig_enh.set_tight_layout(True)
198
+
199
+ def specshow(
200
+ self,
201
+ spec: Union[Tensor, np.ndarray],
202
+ ax: Optional[plt.Axes] = None,
203
+ title: Optional[str] = None,
204
+ xlabel: Optional[str] = None,
205
+ ylabel: Optional[str] = None,
206
+ sr: int = 48000,
207
+ n_fft: Optional[int] = None,
208
+ hop: Optional[int] = None,
209
+ vmin: float = -100,
210
+ vmax: float = 0,
211
+ cmap: str = "inferno",
212
+ ):
213
+ """Plot a spectrogram of shape [F, T]"""
214
+ spec_np = spec.cpu().numpy() if isinstance(spec, torch.Tensor) else spec
215
+
216
+ if n_fft is None:
217
+ n_fft = spec.shape[0] * 2 if spec.shape[0] % 2 == 0 else (spec.shape[0] - 1) * 2
218
+ hop = hop or n_fft // 4
219
+
220
+ t = np.arange(0, spec_np.shape[-1]) * hop / sr
221
+ f = np.arange(0, spec_np.shape[0]) * sr // 2 / (n_fft // 2) / 1000
222
+
223
+ im = ax.pcolormesh(
224
+ t, f, spec_np,
225
+ rasterized=True,
226
+ shading="auto",
227
+ vmin=vmin,
228
+ vmax=vmax,
229
+ cmap=cmap
230
+ )
231
+
232
+ if title:
233
+ ax.set_title(title)
234
+ if xlabel:
235
+ ax.set_xlabel(xlabel)
236
+ if ylabel:
237
+ ax.set_ylabel(ylabel)
238
+
239
+ return im
240
+
241
+ def create_spectrogram(
242
+ self,
243
+ audio: Tensor,
244
+ figure: plt.Figure,
245
+ ax: plt.Axes,
246
+ sr: int = 48000,
247
+ n_fft: int = 1024,
248
+ hop: int = 512,
249
+ title: Optional[str] = None,
250
+ ) -> Image:
251
+ """Create spectrogram image from audio tensor"""
252
+ audio = torch.as_tensor(audio)
253
+
254
+ # Compute STFT
255
+ w = torch.hann_window(n_fft, device=audio.device)
256
+ spec = torch.stft(audio, n_fft, hop, window=w, return_complex=False)
257
+ spec = spec.div_(w.pow(2).sum())
258
+ spec = torch.view_as_complex(spec).abs().clamp_min(1e-12).log10().mul(10)
259
+
260
+ vmax = max(0.0, spec.max().item())
261
+
262
+ if spec.dim() > 2:
263
+ spec = spec.squeeze(0)
264
+
265
+ ax.clear()
266
+ self.specshow(
267
+ spec,
268
+ ax=ax,
269
+ title=title,
270
+ xlabel="Time [s]",
271
+ ylabel="Frequency [kHz]",
272
+ sr=sr,
273
+ n_fft=n_fft,
274
+ hop=hop,
275
+ vmax=vmax,
276
+ )
277
+
278
+ figure.canvas.draw()
279
+ return Image.frombytes(
280
+ "RGB",
281
+ figure.canvas.get_width_height(),
282
+ figure.canvas.tostring_rgb()
283
+ )
284
 
 
 
 
 
 
 
 
 
285
 
286
+ class FileManager:
287
+ """Manages temporary file cleanup"""
288
+
289
+ @staticmethod
290
+ def cleanup_tmp(filter_list: List[str] = None, hours_keep: int = 2, temp_dir: str = "/tmp"):
291
+ """Clean up old temporary files.
292
+
293
+ Args:
294
+ filter_list: List of file patterns to keep
295
+ hours_keep: Number of hours to keep files
296
+ temp_dir: Temporary directory path
297
+ """
298
+ if filter_list is None:
299
+ filter_list = []
300
+ filter_list.append("p232")
301
+
302
+ if not os.path.exists(temp_dir):
303
+ return
304
+
305
+ logger.info(f"Cleaning up temporary files older than {hours_keep} hours")
306
+ cleaned = 0
307
+
308
+ for filepath in glob.glob(os.path.join(temp_dir, "*")):
309
+ try:
310
+ is_old = (time.time() - os.path.getmtime(filepath)) / 3600 > hours_keep
311
+ filtered = any(filt in filepath for filt in filter_list if filt is not None)
312
+
313
+ if is_old and not filtered:
314
+ os.remove(filepath)
315
+ cleaned += 1
316
+ logger.debug(f"Removed file {filepath}")
317
+ except Exception as e:
318
+ logger.warning(f"Failed to remove file {filepath}: {e}")
319
+
320
+ if cleaned > 0:
321
+ logger.info(f"Cleaned up {cleaned} temporary files")
322
+
323
+
324
+ # ============================================================================
325
+ # Initialize Application
326
+ # ============================================================================
327
+
328
+ # Setup configuration
329
+ app_config = AppConfig(
330
+ device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
331
+ )
332
+
333
+ # Initialize model
334
+ logger.info(f"Loading DeepFilterNet2 model on {app_config.device}")
335
+ model, df, _ = init_df(app_config.model_path, config_allow_defaults=True)
336
+ model = model.to(device=app_config.device).eval()
337
+
338
+ # Initialize components
339
+ audio_processor = AudioProcessor(model, df, app_config)
340
+ audio_loader = AudioLoader()
341
+ visualizer = SpectrogramVisualizer()
342
+ file_manager = FileManager()
343
+
344
+ # Noise options
345
  NOISES = {
346
  "None": None,
347
  "Kitchen": "samples/dkitchen.wav",
 
351
  }
352
 
353
 
354
+ # ============================================================================
355
+ # Main Processing Function
356
+ # ============================================================================
357
 
358
+ def process_audio(
359
+ speech_file: Optional[str],
360
+ noise_type: str,
361
+ snr: int,
362
+ mic_input: Optional[str] = None,
363
+ ) -> Tuple[str, Image, str, Image]:
364
+ """Main audio processing pipeline.
365
+
366
  Args:
367
+ speech_file: Path to uploaded audio file
368
+ noise_type: Type of background noise to add
369
+ snr: Signal-to-noise ratio in dB
370
+ mic_input: Path to microphone recording
371
+
372
  Returns:
373
+ Tuple of (noisy_audio_path, noisy_spectrogram, enhanced_audio_path, enhanced_spectrogram)
 
 
 
374
  """
375
+ try:
376
+ # Use mic input if available
377
+ if mic_input:
378
+ speech_file = mic_input
379
+
380
+ sr = app_config.sample_rate
381
+ logger.info(f"Processing: file={speech_file}, noise={noise_type}, snr={snr}")
382
+
383
+ # Load input audio
384
+ if speech_file is not None:
385
+ speech_file = audio_loader.ensure_wav(speech_file)
386
+ sample, meta = load_audio(speech_file, sr)
387
+
388
+ # Limit duration
389
+ max_len = app_config.max_duration_seconds * sr
390
+ if sample.shape[-1] > max_len:
391
+ logger.warning(f"Audio too long, truncating to {app_config.max_duration_seconds}s")
392
+ start = torch.randint(0, sample.shape[-1] - max_len, ()).item()
393
+ sample = sample[..., start : start + max_len]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
  else:
395
+ # Use default sample
396
+ sample, meta = load_audio("samples/p232_013_clean.wav", sr)
397
+ sample = sample[..., : app_config.max_duration_seconds * sr]
398
+
399
+ # Convert to mono if needed
400
+ if sample.dim() > 1 and sample.shape[0] > 1:
401
+ logger.info(f"Converting from {sample.shape[0]} channels to mono")
402
+ sample = sample.mean(dim=0, keepdim=True)
403
+
404
+ logger.info(f"Loaded audio with shape {sample.shape}")
405
+
406
+ # Add noise if specified
407
+ noise_fn = NOISES.get(noise_type)
408
+ if noise_fn is not None:
409
+ noise, _ = load_audio(noise_fn, sr)
410
+ logger.info(f"Adding {noise_type} noise at {snr} dB SNR")
411
+ _, _, sample = audio_processor.mix_at_snr(sample, noise, int(snr))
412
+
413
+ # Enhance audio
414
+ enhanced = audio_processor.enhance_audio(sample)
415
+ logger.info("Audio enhancement completed")
416
+
417
+ # Resample if needed
418
+ if meta.sample_rate != sr and meta.sample_rate > 0:
419
+ enhanced = resample(enhanced, sr, meta.sample_rate)
420
+ sample = resample(sample, sr, meta.sample_rate)
421
+ sr = meta.sample_rate
422
+
423
+ # Save audio files
424
+ noisy_wav = tempfile.NamedTemporaryFile(suffix="_noisy.wav", delete=False).name
425
+ save_audio(noisy_wav, sample, sr)
426
+
427
+ enhanced_wav = tempfile.NamedTemporaryFile(suffix="_enhanced.wav", delete=False).name
428
+ save_audio(enhanced_wav, enhanced, sr)
429
+
430
+ logger.info(f"Saved outputs: {noisy_wav}, {enhanced_wav}")
431
+
432
+ # Create spectrograms
433
+ noisy_spec = visualizer.create_spectrogram(
434
+ sample,
435
+ visualizer.fig_noisy,
436
+ visualizer.ax_noisy,
437
+ sr=sr,
438
+ title="Noisy Audio Spectrogram"
439
+ )
440
+
441
+ enhanced_spec = visualizer.create_spectrogram(
442
+ enhanced,
443
+ visualizer.fig_enh,
444
+ visualizer.ax_enh,
445
+ sr=sr,
446
+ title="Enhanced Audio Spectrogram"
447
+ )
448
+
449
+ # Cleanup old files
450
+ filter_files = [speech_file, noisy_wav, enhanced_wav]
451
+ if mic_input:
452
+ filter_files.append(mic_input)
453
+ file_manager.cleanup_tmp(filter_files, app_config.cleanup_hours)
454
+
455
+ return noisy_wav, noisy_spec, enhanced_wav, enhanced_spec
456
+
457
+ except Exception as e:
458
+ logger.error(f"Error processing audio: {e}", exc_info=True)
459
+ raise gr.Error(f"Processing failed: {str(e)}")
460
+
461
+
462
+ def toggle_input_mode(choice: str):
463
+ """Toggle between microphone and file upload."""
 
 
 
 
 
 
 
 
 
 
464
  if choice == "mic":
465
  return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
466
  else:
467
  return gr.update(visible=False, value=None), gr.update(visible=True, value=None)
468
 
469
 
470
+ # ============================================================================
471
+ # Gradio Interface
472
+ # ============================================================================
473
+
474
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
475
+ gr.Markdown(
476
+ """
477
+ # 🎵 DeepFilterNet2 Audio Denoising Demo
478
+
479
+ Remove background noise from your audio recordings using state-of-the-art deep learning.
480
+ Upload an audio file or record directly, optionally add synthetic noise, and enhance the quality.
481
+
482
+ **Features:**
483
+ - Support for multiple audio formats (MP3, WAV, M4A, OGG, FLAC)
484
+ - Real-time microphone recording
485
+ - Customizable background noise injection
486
+ - Visual spectrogram comparison
487
+ - Up to 1 hour of audio processing
488
+ """
489
+ )
490
+
491
  with gr.Row():
492
+ with gr.Column(scale=1):
493
+ gr.Markdown("### Input Settings")
494
+
495
+ input_mode = gr.Radio(
496
+ ["file", "mic"],
497
+ value="file",
498
+ label="Input Method",
499
+ info="Choose how to provide your audio"
500
+ )
501
+
502
+ audio_file = gr.Audio(
503
+ type="filepath",
504
+ label="Upload Audio File",
505
+ visible=True
506
+ )
507
+
508
+ mic_input = gr.Audio(
509
+ sources=["microphone"],
510
+ type="filepath",
511
+ label="Record Audio",
512
+ visible=False
513
  )
514
+
515
+ gr.Markdown("### Enhancement Settings")
516
+
517
+ noise_type = gr.Dropdown(
518
+ label="Background Noise Type",
519
+ choices=list(NOISES.keys()),
520
+ value="None",
521
+ info="Add synthetic background noise for testing"
522
+ )
523
+
524
+ snr = gr.Dropdown(
525
+ label="Signal-to-Noise Ratio (dB)",
526
+ choices=["-5", "0", "10", "20"],
527
+ value="10",
528
+ info="Higher values = less noise"
529
+ )
530
+
531
+ process_btn = gr.Button("🚀 Denoise Audio", variant="primary", size="lg")
532
+
533
+ with gr.Column(scale=2):
534
+ gr.Markdown("### Results")
535
+
536
+ with gr.Tab("Noisy Audio"):
537
+ noisy_audio = gr.Audio(type="filepath", label="Noisy Audio")
538
+ noisy_spec = gr.Image(label="Noisy Spectrogram")
539
+
540
+ with gr.Tab("Enhanced Audio"):
541
+ enhanced_audio = gr.Audio(type="filepath", label="Enhanced Audio")
542
+ enhanced_spec = gr.Image(label="Enhanced Spectrogram")
543
+
544
+ # Examples
545
+ gr.Markdown("### 📝 Example Inputs")
546
  gr.Examples(
547
+ examples=[
548
  ["./samples/p232_013_clean.wav", "Kitchen", "10"],
549
  ["./samples/p232_013_clean.wav", "Cafe", "10"],
550
  ["./samples/p232_019_clean.wav", "Cafe", "10"],
551
  ["./samples/p232_019_clean.wav", "River", "10"],
552
  ],
553
+ inputs=[audio_file, noise_type, snr],
554
+ outputs=[noisy_audio, noisy_spec, enhanced_audio, enhanced_spec],
555
+ fn=process_audio,
556
  cache_examples=True,
557
+ label="Try these examples",
558
+ )
559
+
560
+ # Information
561
+ gr.Markdown(
562
+ """
563
+ ### ℹ️ How It Works
564
+
565
+ 1. **Upload or Record**: Choose your input method and provide audio
566
+ 2. **Configure** (Optional): Add synthetic noise for testing the denoiser
567
+ 3. **Process**: Click "Denoise Audio" to enhance your recording
568
+ 4. **Compare**: View spectrograms and listen to before/after results
569
+
570
+ ### 📊 Technical Details
571
+
572
+ - **Model**: DeepFilterNet2 - Real-time speech enhancement
573
+ - **Max Duration**: 1 hour per file
574
+ - **Sample Rate**: 48 kHz
575
+ - **Supported Formats**: WAV, MP3, M4A, OGG, FLAC, AAC
576
+
577
+ ### 🎯 Best Results
578
+
579
+ - Use clear speech recordings
580
+ - Avoid extreme clipping or distortion
581
+ - For best quality, use WAV format at 48kHz
582
+ """
583
+ )
584
+
585
+ # Event handlers
586
+ process_btn.click(
587
+ fn=process_audio,
588
+ inputs=[audio_file, noise_type, snr, mic_input],
589
+ outputs=[noisy_audio, noisy_spec, enhanced_audio, enhanced_spec],
590
+ api_name="denoise",
591
+ )
592
+
593
+ input_mode.change(
594
+ fn=toggle_input_mode,
595
+ inputs=input_mode,
596
+ outputs=[mic_input, audio_file],
597
+ )
598
 
599
+ # Initial cleanup
600
+ file_manager.cleanup_tmp()
601
 
602
+ # Launch application
603
+ if __name__ == "__main__":
604
+ demo.queue().launch(
605
+ server_name="0.0.0.0",
606
+ server_port=7860,
607
+ share=True,
608
+ )