humairmunirawn commited on
Commit
3132eeb
·
verified ·
1 Parent(s): 3f3d91b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +258 -19
app.py CHANGED
@@ -4,18 +4,20 @@ import os
4
  import tempfile
5
  import time
6
  from typing import List, Optional, Tuple, Union
 
7
 
8
  import gradio as gr
9
  import matplotlib.pyplot as plt
10
  import numpy as np
11
  import torch
 
12
  from loguru import logger
13
  from PIL import Image
14
  from torch import Tensor
15
- from torchaudio.backend.common import AudioMetaData
16
 
17
  from df import config
18
- from df.enhance import enhance, init_df, load_audio, save_audio
19
  from df.io import resample
20
 
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -40,6 +42,138 @@ NOISES = {
40
  }
41
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def mix_at_snr(clean, noise, snr, eps=1e-10):
44
  """Mix clean and noise signal at a given SNR.
45
 
@@ -67,7 +201,7 @@ def mix_at_snr(clean, noise, snr, eps=1e-10):
67
  K = torch.sqrt((E_noise / E_speech) * 10 ** (snr / 10) + eps)
68
  noise = noise / K
69
  mixture = clean + noise
70
- logger.debug("mixture: {mixture.shape}")
71
  assert torch.isfinite(mixture).all()
72
  max_m = mixture.abs().max()
73
  if max_m > 1:
@@ -79,36 +213,77 @@ def mix_at_snr(clean, noise, snr, eps=1e-10):
79
  def load_audio_gradio(
80
  audio_or_file: Union[None, str, Tuple[int, np.ndarray]], sr: int
81
  ) -> Optional[Tuple[Tensor, AudioMetaData]]:
 
 
 
 
 
 
 
 
 
82
  if audio_or_file is None:
83
  return None
 
84
  if isinstance(audio_or_file, str):
85
  if audio_or_file.lower() == "none":
86
  return None
87
- # First try default format
88
  audio, meta = load_audio(audio_or_file, sr)
89
  else:
90
- meta = AudioMetaData(-1, -1, -1, -1, "")
 
 
 
 
 
 
 
91
  assert isinstance(audio_or_file, (tuple, list))
92
- meta.sample_rate, audio_np = audio_or_file
93
- # Gradio documentation says, the shape is [samples, 2], but apparently sometimes its not.
 
94
  audio_np = audio_np.reshape(audio_np.shape[0], -1).T
 
 
95
  if audio_np.dtype == np.int16:
96
  audio_np = (audio_np / (1 << 15)).astype(np.float32)
97
  elif audio_np.dtype == np.int32:
98
  audio_np = (audio_np / (1 << 31)).astype(np.float32)
99
- audio = resample(torch.from_numpy(audio_np), meta.sample_rate, sr)
 
 
 
 
 
 
 
 
100
  return audio, meta
101
 
102
 
103
  def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str] = None):
 
 
 
 
 
 
 
 
 
 
 
104
  if mic_input:
105
  speech_upl = mic_input
 
106
  sr = config("sr", 48000, int, section="df")
107
  logger.info(f"Got parameters speech_upl: {speech_upl}, noise: {noise_type}, snr: {snr}")
108
  snr = int(snr)
109
  noise_fn = NOISES[noise_type]
110
  meta = AudioMetaData(-1, -1, -1, -1, "")
111
  max_s = 10 # limit to 10 seconds
 
112
  if speech_upl is not None:
113
  sample, meta = load_audio(speech_upl, sr)
114
  max_len = max_s * sr
@@ -118,39 +293,56 @@ def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str]
118
  else:
119
  sample, meta = load_audio("samples/p232_013_clean.wav", sr)
120
  sample = sample[..., : max_s * sr]
 
121
  if sample.dim() > 1 and sample.shape[0] > 1:
122
  assert (
123
  sample.shape[1] > sample.shape[0]
124
  ), f"Expecting channels first, but got {sample.shape}"
125
  sample = sample.mean(dim=0, keepdim=True)
 
126
  logger.info(f"Loaded sample with shape {sample.shape}")
 
127
  if noise_fn is not None:
128
- noise, _ = load_audio(noise_fn, sr) # type: ignore
129
  logger.info(f"Loaded noise with shape {noise.shape}")
130
  _, _, sample = mix_at_snr(sample, noise, snr)
 
131
  logger.info("Start denoising audio")
132
  enhanced = enhance(model, df, sample)
133
  logger.info("Denoising finished")
 
 
134
  lim = torch.linspace(0.0, 1.0, int(sr * 0.15)).unsqueeze(0)
135
  lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1)
136
  enhanced = enhanced * lim
 
 
137
  if meta.sample_rate != sr:
138
- enhanced = resample(enhanced, sr, meta.sample_rate)
139
- sample = resample(sample, sr, meta.sample_rate)
140
  sr = meta.sample_rate
141
- noisy_wav = tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False).name
 
 
142
  save_audio(noisy_wav, sample, sr)
143
- enhanced_wav = tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False).name
 
144
  save_audio(enhanced_wav, enhanced, sr)
 
145
  logger.info(f"saved audios: {noisy_wav}, {enhanced_wav}")
 
 
146
  ax_noisy.clear()
147
  ax_enh.clear()
148
  noisy_im = spec_im(sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
149
  enh_im = spec_im(enhanced, sr=sr, figure=fig_enh, ax=ax_enh)
 
 
150
  filter = [speech_upl, noisy_wav, enhanced_wav]
151
  if mic_input is not None and mic_input != "":
152
  filter.append(mic_input)
153
  cleanup_tmp(filter)
 
154
  return noisy_wav, noisy_im, enhanced_wav, enh_im
155
 
156
 
@@ -186,19 +378,25 @@ def specshow(
186
  set_ylabel = plt.ylabel
187
  set_xlim = plt.xlim
188
  set_ylim = plt.ylim
 
189
  if n_fft is None:
190
  if spec.shape[0] % 2 == 0:
191
  n_fft = spec.shape[0] * 2
192
  else:
193
  n_fft = (spec.shape[0] - 1) * 2
 
194
  hop = hop or n_fft // 4
 
195
  if t is None:
196
  t = np.arange(0, spec_np.shape[-1]) * hop / sr
 
197
  if f is None:
198
  f = np.arange(0, spec_np.shape[0]) * sr // 2 / (n_fft // 2) / 1000
 
199
  im = ax.pcolormesh(
200
  t, f, spec_np, rasterized=True, shading="auto", vmin=vmin, vmax=vmax, cmap=cmap
201
  )
 
202
  if title is not None:
203
  set_title(title)
204
  if xlabel is not None:
@@ -209,6 +407,7 @@ def specshow(
209
  set_xlim(xlim)
210
  if ylim is not None:
211
  set_ylim(ylim)
 
212
  return im
213
 
214
 
@@ -221,24 +420,44 @@ def spec_im(
221
  labels=True,
222
  **kwargs,
223
  ) -> Image:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  audio = torch.as_tensor(audio)
 
225
  if labels:
226
  kwargs.setdefault("xlabel", "Time [s]")
227
  kwargs.setdefault("ylabel", "Frequency [Hz]")
 
228
  n_fft = kwargs.setdefault("n_fft", 1024)
229
  hop = kwargs.setdefault("hop", 512)
 
230
  w = torch.hann_window(n_fft, device=audio.device)
231
  spec = torch.stft(audio, n_fft, hop, window=w, return_complex=False)
232
  spec = spec.div_(w.pow(2).sum())
233
  spec = torch.view_as_complex(spec).abs().clamp_min(1e-12).log10().mul(10)
234
  kwargs.setdefault("vmax", max(0.0, spec.max().item()))
235
-
236
  if figure is None:
237
  figure = plt.figure(figsize=figsize)
238
  figure.set_tight_layout(True)
 
239
  if spec.dim() > 2:
240
  spec = spec.squeeze(0)
 
241
  im = specshow(spec, **kwargs)
 
242
  if colorbar:
243
  ckwargs = {}
244
  if "ax" in kwargs:
@@ -247,13 +466,21 @@ def spec_im(
247
  colorbar_format = "%+2.0f dB"
248
  ckwargs = {"ax": kwargs["ax"]}
249
  plt.colorbar(im, format=colorbar_format, **ckwargs)
 
250
  figure.canvas.draw()
251
  return Image.frombytes("RGB", figure.canvas.get_width_height(), figure.canvas.tostring_rgb())
252
 
253
 
254
  def cleanup_tmp(filter: List[str] = [], hours_keep=2):
 
 
 
 
 
 
255
  filter.append("p232")
256
  logger.info(f"Filter: {filter}")
 
257
  # Cleanup some old wav files
258
  if os.path.exists("/tmp"):
259
  for f in glob.glob("/tmp/*"):
@@ -269,21 +496,31 @@ def cleanup_tmp(filter: List[str] = [], hours_keep=2):
269
 
270
 
271
  def toggle(choice):
 
 
 
 
 
 
 
 
272
  if choice == "mic":
273
  return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
274
  else:
275
  return gr.update(visible=False, value=None), gr.update(visible=True, value=None)
276
 
277
 
 
278
  with gr.Blocks() as demo:
279
  with gr.Row():
280
  gr.Markdown(
281
  """
282
- ## DeepFilterNet2 Demo\
283
 
284
  This demo denoises audio files using DeepFilterNet. Try it with your own voice!
285
  """
286
  )
 
287
  with gr.Row():
288
  with gr.Column():
289
  radio = gr.Radio(
@@ -306,17 +543,18 @@ with gr.Blocks() as demo:
306
  mic_input,
307
  ]
308
  btn = gr.Button("Generate")
 
309
  with gr.Column():
310
  outputs = [
311
- # gr.Video(type="filepath", label="Noisy audio"),
312
  gr.Audio(type="filepath", label="Noisy audio"),
313
  gr.Image(label="Noisy spectrogram"),
314
- # gr.Video(type="filepath", label="Enhanced audio"),
315
  gr.Audio(type="filepath", label="Enhanced audio"),
316
  gr.Image(label="Enhanced spectrogram"),
317
  ]
 
318
  btn.click(fn=demo_fn, inputs=inputs, outputs=outputs, api_name='denoise')
319
  radio.change(toggle, radio, [mic_input, audio_file])
 
320
  gr.Examples(
321
  [
322
  ["./samples/p232_013_clean.wav", "Kitchen", "10"],
@@ -328,8 +566,9 @@ with gr.Blocks() as demo:
328
  inputs=inputs,
329
  outputs=outputs,
330
  cache_examples=True,
331
- ),
 
332
  gr.Markdown(open("usage.md").read())
333
 
334
  cleanup_tmp()
335
- demo.launch(enable_queue=True)
 
4
  import tempfile
5
  import time
6
  from typing import List, Optional, Tuple, Union
7
+ from dataclasses import dataclass
8
 
9
  import gradio as gr
10
  import matplotlib.pyplot as plt
11
  import numpy as np
12
  import torch
13
+ import soundfile as sf
14
  from loguru import logger
15
  from PIL import Image
16
  from torch import Tensor
17
+ from scipy import signal
18
 
19
  from df import config
20
+ from df.enhance import enhance, init_df
21
  from df.io import resample
22
 
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
42
  }
43
 
44
 
45
+ @dataclass
46
+ class AudioMetaData:
47
+ """Simple audio metadata container to replace torchaudio.backend.common.AudioMetaData"""
48
+ sample_rate: int
49
+ num_frames: int
50
+ num_channels: int
51
+ bits_per_sample: int
52
+ encoding: str
53
+
54
+
55
+ def load_audio(file_path: str, sr: int) -> Tuple[Tensor, AudioMetaData]:
56
+ """Load audio file using soundfile and resample if necessary.
57
+
58
+ Args:
59
+ file_path: Path to audio file
60
+ sr: Target sample rate
61
+
62
+ Returns:
63
+ audio: Torch tensor of shape [channels, samples]
64
+ meta: AudioMetaData with file information
65
+ """
66
+ try:
67
+ # Read audio using soundfile
68
+ audio_np, sample_rate = sf.read(file_path, dtype='float32')
69
+
70
+ # Handle mono/stereo
71
+ if audio_np.ndim == 1:
72
+ audio_np = audio_np[np.newaxis, :] # Add channel dimension
73
+ num_channels = 1
74
+ else:
75
+ audio_np = audio_np.T # Convert [samples, channels] to [channels, samples]
76
+ num_channels = audio_np.shape[0]
77
+
78
+ # Get file info for metadata
79
+ info = sf.info(file_path)
80
+ num_frames = info.frames
81
+
82
+ # Create metadata
83
+ meta = AudioMetaData(
84
+ sample_rate=sample_rate,
85
+ num_frames=num_frames,
86
+ num_channels=num_channels,
87
+ bits_per_sample=-1, # Not directly available from soundfile
88
+ encoding=info.format
89
+ )
90
+
91
+ # Convert to torch tensor
92
+ audio = torch.from_numpy(audio_np).float()
93
+
94
+ # Resample if necessary
95
+ if sample_rate != sr:
96
+ audio = resample_audio(audio, sample_rate, sr)
97
+ meta.sample_rate = sr
98
+
99
+ return audio, meta
100
+
101
+ except Exception as e:
102
+ logger.error(f"Error loading audio file {file_path}: {e}")
103
+ raise
104
+
105
+
106
+ def save_audio(file_path: str, audio: Tensor, sr: int) -> None:
107
+ """Save audio tensor to file using soundfile.
108
+
109
+ Args:
110
+ file_path: Output file path
111
+ audio: Audio tensor of shape [channels, samples] or [samples]
112
+ sr: Sample rate
113
+ """
114
+ try:
115
+ # Convert tensor to numpy
116
+ audio_np = audio.cpu().numpy()
117
+
118
+ # Handle tensor shape
119
+ if audio_np.ndim == 3:
120
+ audio_np = audio_np.squeeze(0)
121
+
122
+ # Convert [channels, samples] to [samples, channels] for soundfile
123
+ if audio_np.ndim == 2:
124
+ audio_np = audio_np.T
125
+
126
+ # Ensure float32
127
+ audio_np = audio_np.astype(np.float32)
128
+
129
+ # Clip to valid range
130
+ audio_np = np.clip(audio_np, -1.0, 1.0)
131
+
132
+ # Save using soundfile
133
+ sf.write(file_path, audio_np, sr)
134
+ logger.info(f"Saved audio to {file_path}")
135
+
136
+ except Exception as e:
137
+ logger.error(f"Error saving audio to {file_path}: {e}")
138
+ raise
139
+
140
+
141
+ def resample_audio(audio: Tensor, sr_orig: int, sr_target: int) -> Tensor:
142
+ """Resample audio using scipy.signal.resample_poly.
143
+
144
+ Args:
145
+ audio: Audio tensor of shape [channels, samples]
146
+ sr_orig: Original sample rate
147
+ sr_target: Target sample rate
148
+
149
+ Returns:
150
+ Resampled audio tensor
151
+ """
152
+ if sr_orig == sr_target:
153
+ return audio
154
+
155
+ # Convert to numpy for resampling
156
+ audio_np = audio.cpu().numpy()
157
+
158
+ # Calculate gcd for polyphase resampling
159
+ from math import gcd
160
+ g = gcd(sr_orig, sr_target)
161
+ up = sr_target // g
162
+ down = sr_orig // g
163
+
164
+ logger.debug(f"Resampling from {sr_orig} to {sr_target} (up={up}, down={down})")
165
+
166
+ # Resample each channel
167
+ if audio_np.ndim == 2:
168
+ resampled = np.zeros((audio_np.shape[0], int(audio_np.shape[1] * sr_target / sr_orig)))
169
+ for ch in range(audio_np.shape[0]):
170
+ resampled[ch] = signal.resample_poly(audio_np[ch], up, down)
171
+ else:
172
+ resampled = signal.resample_poly(audio_np, up, down)
173
+
174
+ return torch.from_numpy(resampled).float()
175
+
176
+
177
  def mix_at_snr(clean, noise, snr, eps=1e-10):
178
  """Mix clean and noise signal at a given SNR.
179
 
 
201
  K = torch.sqrt((E_noise / E_speech) * 10 ** (snr / 10) + eps)
202
  noise = noise / K
203
  mixture = clean + noise
204
+ logger.debug(f"mixture: {mixture.shape}")
205
  assert torch.isfinite(mixture).all()
206
  max_m = mixture.abs().max()
207
  if max_m > 1:
 
213
  def load_audio_gradio(
214
  audio_or_file: Union[None, str, Tuple[int, np.ndarray]], sr: int
215
  ) -> Optional[Tuple[Tensor, AudioMetaData]]:
216
+ """Load audio from file or gradio microphone input.
217
+
218
+ Args:
219
+ audio_or_file: Path to audio file, tuple from gradio mic, or None
220
+ sr: Target sample rate
221
+
222
+ Returns:
223
+ Tuple of (audio tensor, metadata) or None
224
+ """
225
  if audio_or_file is None:
226
  return None
227
+
228
  if isinstance(audio_or_file, str):
229
  if audio_or_file.lower() == "none":
230
  return None
231
+ # Load from file path
232
  audio, meta = load_audio(audio_or_file, sr)
233
  else:
234
+ # Handle gradio microphone input
235
+ meta = AudioMetaData(
236
+ sample_rate=-1,
237
+ num_frames=-1,
238
+ num_channels=-1,
239
+ bits_per_sample=-1,
240
+ encoding=""
241
+ )
242
  assert isinstance(audio_or_file, (tuple, list))
243
+ sample_rate, audio_np = audio_or_file
244
+
245
+ # Gradio returns [samples, channels], reshape if needed
246
  audio_np = audio_np.reshape(audio_np.shape[0], -1).T
247
+
248
+ # Handle different integer formats
249
  if audio_np.dtype == np.int16:
250
  audio_np = (audio_np / (1 << 15)).astype(np.float32)
251
  elif audio_np.dtype == np.int32:
252
  audio_np = (audio_np / (1 << 31)).astype(np.float32)
253
+
254
+ audio = torch.from_numpy(audio_np).float()
255
+
256
+ # Resample if necessary
257
+ if sample_rate != sr:
258
+ audio = resample_audio(audio, sample_rate, sr)
259
+
260
+ meta.sample_rate = sr
261
+
262
  return audio, meta
263
 
264
 
265
  def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str] = None):
266
+ """Main demo function for audio denoising.
267
+
268
+ Args:
269
+ speech_upl: Path to uploaded speech file
270
+ noise_type: Type of noise to add
271
+ snr: Signal-to-noise ratio
272
+ mic_input: Path to microphone input file
273
+
274
+ Returns:
275
+ Tuple of (noisy_audio_path, noisy_spectrogram, enhanced_audio_path, enhanced_spectrogram)
276
+ """
277
  if mic_input:
278
  speech_upl = mic_input
279
+
280
  sr = config("sr", 48000, int, section="df")
281
  logger.info(f"Got parameters speech_upl: {speech_upl}, noise: {noise_type}, snr: {snr}")
282
  snr = int(snr)
283
  noise_fn = NOISES[noise_type]
284
  meta = AudioMetaData(-1, -1, -1, -1, "")
285
  max_s = 10 # limit to 10 seconds
286
+
287
  if speech_upl is not None:
288
  sample, meta = load_audio(speech_upl, sr)
289
  max_len = max_s * sr
 
293
  else:
294
  sample, meta = load_audio("samples/p232_013_clean.wav", sr)
295
  sample = sample[..., : max_s * sr]
296
+
297
  if sample.dim() > 1 and sample.shape[0] > 1:
298
  assert (
299
  sample.shape[1] > sample.shape[0]
300
  ), f"Expecting channels first, but got {sample.shape}"
301
  sample = sample.mean(dim=0, keepdim=True)
302
+
303
  logger.info(f"Loaded sample with shape {sample.shape}")
304
+
305
  if noise_fn is not None:
306
+ noise, _ = load_audio(noise_fn, sr)
307
  logger.info(f"Loaded noise with shape {noise.shape}")
308
  _, _, sample = mix_at_snr(sample, noise, snr)
309
+
310
  logger.info("Start denoising audio")
311
  enhanced = enhance(model, df, sample)
312
  logger.info("Denoising finished")
313
+
314
+ # Apply fade-in limiter
315
  lim = torch.linspace(0.0, 1.0, int(sr * 0.15)).unsqueeze(0)
316
  lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1)
317
  enhanced = enhanced * lim
318
+
319
+ # Resample back to original sample rate if needed
320
  if meta.sample_rate != sr:
321
+ enhanced = resample_audio(enhanced, sr, meta.sample_rate)
322
+ sample = resample_audio(sample, sr, meta.sample_rate)
323
  sr = meta.sample_rate
324
+
325
+ # Save audio files
326
+ noisy_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name
327
  save_audio(noisy_wav, sample, sr)
328
+
329
+ enhanced_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name
330
  save_audio(enhanced_wav, enhanced, sr)
331
+
332
  logger.info(f"saved audios: {noisy_wav}, {enhanced_wav}")
333
+
334
+ # Generate spectrograms
335
  ax_noisy.clear()
336
  ax_enh.clear()
337
  noisy_im = spec_im(sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
338
  enh_im = spec_im(enhanced, sr=sr, figure=fig_enh, ax=ax_enh)
339
+
340
+ # Cleanup temporary files (except the ones we want to return)
341
  filter = [speech_upl, noisy_wav, enhanced_wav]
342
  if mic_input is not None and mic_input != "":
343
  filter.append(mic_input)
344
  cleanup_tmp(filter)
345
+
346
  return noisy_wav, noisy_im, enhanced_wav, enh_im
347
 
348
 
 
378
  set_ylabel = plt.ylabel
379
  set_xlim = plt.xlim
380
  set_ylim = plt.ylim
381
+
382
  if n_fft is None:
383
  if spec.shape[0] % 2 == 0:
384
  n_fft = spec.shape[0] * 2
385
  else:
386
  n_fft = (spec.shape[0] - 1) * 2
387
+
388
  hop = hop or n_fft // 4
389
+
390
  if t is None:
391
  t = np.arange(0, spec_np.shape[-1]) * hop / sr
392
+
393
  if f is None:
394
  f = np.arange(0, spec_np.shape[0]) * sr // 2 / (n_fft // 2) / 1000
395
+
396
  im = ax.pcolormesh(
397
  t, f, spec_np, rasterized=True, shading="auto", vmin=vmin, vmax=vmax, cmap=cmap
398
  )
399
+
400
  if title is not None:
401
  set_title(title)
402
  if xlabel is not None:
 
407
  set_xlim(xlim)
408
  if ylim is not None:
409
  set_ylim(ylim)
410
+
411
  return im
412
 
413
 
 
420
  labels=True,
421
  **kwargs,
422
  ) -> Image:
423
+ """Convert audio to spectrogram image.
424
+
425
+ Args:
426
+ audio: Audio tensor
427
+ figsize: Figure size
428
+ colorbar: Whether to show colorbar
429
+ colorbar_format: Format for colorbar
430
+ figure: Matplotlib figure to use
431
+ labels: Whether to show axis labels
432
+ **kwargs: Additional arguments for specshow
433
+
434
+ Returns:
435
+ PIL Image of the spectrogram
436
+ """
437
  audio = torch.as_tensor(audio)
438
+
439
  if labels:
440
  kwargs.setdefault("xlabel", "Time [s]")
441
  kwargs.setdefault("ylabel", "Frequency [Hz]")
442
+
443
  n_fft = kwargs.setdefault("n_fft", 1024)
444
  hop = kwargs.setdefault("hop", 512)
445
+
446
  w = torch.hann_window(n_fft, device=audio.device)
447
  spec = torch.stft(audio, n_fft, hop, window=w, return_complex=False)
448
  spec = spec.div_(w.pow(2).sum())
449
  spec = torch.view_as_complex(spec).abs().clamp_min(1e-12).log10().mul(10)
450
  kwargs.setdefault("vmax", max(0.0, spec.max().item()))
451
+
452
  if figure is None:
453
  figure = plt.figure(figsize=figsize)
454
  figure.set_tight_layout(True)
455
+
456
  if spec.dim() > 2:
457
  spec = spec.squeeze(0)
458
+
459
  im = specshow(spec, **kwargs)
460
+
461
  if colorbar:
462
  ckwargs = {}
463
  if "ax" in kwargs:
 
466
  colorbar_format = "%+2.0f dB"
467
  ckwargs = {"ax": kwargs["ax"]}
468
  plt.colorbar(im, format=colorbar_format, **ckwargs)
469
+
470
  figure.canvas.draw()
471
  return Image.frombytes("RGB", figure.canvas.get_width_height(), figure.canvas.tostring_rgb())
472
 
473
 
474
  def cleanup_tmp(filter: List[str] = [], hours_keep=2):
475
+ """Clean up old temporary files.
476
+
477
+ Args:
478
+ filter: List of file paths to keep (not delete)
479
+ hours_keep: Number of hours to keep files
480
+ """
481
  filter.append("p232")
482
  logger.info(f"Filter: {filter}")
483
+
484
  # Cleanup some old wav files
485
  if os.path.exists("/tmp"):
486
  for f in glob.glob("/tmp/*"):
 
496
 
497
 
498
  def toggle(choice):
499
+ """Toggle between microphone and file input.
500
+
501
+ Args:
502
+ choice: "mic" or "file"
503
+
504
+ Returns:
505
+ Tuple of updated components visibility
506
+ """
507
  if choice == "mic":
508
  return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
509
  else:
510
  return gr.update(visible=False, value=None), gr.update(visible=True, value=None)
511
 
512
 
513
+ # Create Gradio interface
514
  with gr.Blocks() as demo:
515
  with gr.Row():
516
  gr.Markdown(
517
  """
518
+ ## DeepFilterNet2 Demo
519
 
520
  This demo denoises audio files using DeepFilterNet. Try it with your own voice!
521
  """
522
  )
523
+
524
  with gr.Row():
525
  with gr.Column():
526
  radio = gr.Radio(
 
543
  mic_input,
544
  ]
545
  btn = gr.Button("Generate")
546
+
547
  with gr.Column():
548
  outputs = [
 
549
  gr.Audio(type="filepath", label="Noisy audio"),
550
  gr.Image(label="Noisy spectrogram"),
 
551
  gr.Audio(type="filepath", label="Enhanced audio"),
552
  gr.Image(label="Enhanced spectrogram"),
553
  ]
554
+
555
  btn.click(fn=demo_fn, inputs=inputs, outputs=outputs, api_name='denoise')
556
  radio.change(toggle, radio, [mic_input, audio_file])
557
+
558
  gr.Examples(
559
  [
560
  ["./samples/p232_013_clean.wav", "Kitchen", "10"],
 
566
  inputs=inputs,
567
  outputs=outputs,
568
  cache_examples=True,
569
+ )
570
+
571
  gr.Markdown(open("usage.md").read())
572
 
573
  cleanup_tmp()
574
+ demo.launch(enable_queue=True)