humairawan committed
Commit cf1b40e · verified · 1 Parent(s): fb24e3e

Update app.py

Files changed (1):
1. app.py +661 -81
app.py CHANGED
@@ -29,124 +29,281 @@ from df.io import resample
 class AppConfig:
     """Application configuration"""
     device: torch.device
-    model_sample_rate: int = 48000
     max_duration_seconds: int = 3600
     cleanup_hours: int = 2
     temp_dir: str = "/tmp"
     model_path: str = "./DeepFilterNet2"
     fade_duration: float = 0.15
-
-# ============================================================================
-# Audio Processing Classes
-# ============================================================================
 
 class AudioProcessor:
     def __init__(self, model, df, config: AppConfig):
         self.model = model
         self.df = df
         self.config = config
-
     def mix_at_snr(self, clean: Tensor, noise: Tensor, snr: float, eps: float = 1e-10) -> Tuple[Tensor, Tensor, Tensor]:
         clean = torch.as_tensor(clean).mean(0, keepdim=True)
         noise = torch.as_tensor(noise).mean(0, keepdim=True)
         if noise.shape[1] < clean.shape[1]:
             repeats = int(math.ceil(clean.shape[1] / noise.shape[1]))
             noise = noise.repeat((1, repeats))
         max_start = int(noise.shape[1] - clean.shape[1])
         start = torch.randint(0, max_start, ()).item() if max_start > 0 else 0
-        noise = noise[:, start:start + clean.shape[1]]
         E_speech = torch.mean(clean.pow(2)) + eps
         E_noise = torch.mean(noise.pow(2)) + eps
         K = torch.sqrt((E_noise / E_speech) * 10 ** (snr / 10) + eps)
         noise = noise / K
         mixture = clean + noise
         max_m = mixture.abs().max()
         if max_m > 1:
             clean, noise, mixture = clean / max_m, noise / max_m, mixture / max_m
         return clean, noise, mixture
-
     def enhance_audio(self, audio: Tensor) -> Tensor:
         with torch.no_grad():
             enhanced = enhance(self.model, self.df, audio)
-        sr = self.config.model_sample_rate
         fade_samples = int(sr * self.config.fade_duration)
         lim = torch.linspace(0.0, 1.0, fade_samples).unsqueeze(0)
-        lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - fade_samples)), dim=1)
-        return enhanced * lim
 
 class AudioLoader:
     @staticmethod
     def ensure_wav(filepath: str) -> str:
         if not filepath:
             return filepath
-        ext = Path(filepath).suffix.lower()
-        if ext in [".mp3", ".m4a", ".ogg", ".flac", ".aac"]:
-            wav_path = str(Path(filepath).with_suffix(".wav"))
-            subprocess.run(["ffmpeg", "-y", "-i", filepath, "-acodec", "pcm_s16le", wav_path],
-                           check=True, capture_output=True)
-            return wav_path
         return filepath
-
     @staticmethod
-    def load_and_resample(audio_or_file: Union[str, Tuple[int, np.ndarray]], target_sr: int) -> Tensor:
         if isinstance(audio_or_file, str):
             audio_or_file = AudioLoader.ensure_wav(audio_or_file)
-            audio, meta = load_audio(audio_or_file, target_sr)
         else:
-            sr, audio_np = audio_or_file
             audio_np = audio_np.reshape(audio_np.shape[0], -1).T
             if audio_np.dtype == np.int16:
                 audio_np = (audio_np / (1 << 15)).astype(np.float32)
             elif audio_np.dtype == np.int32:
                 audio_np = (audio_np / (1 << 31)).astype(np.float32)
-            audio = torch.from_numpy(audio_np)
-            if sr != target_sr:
-                audio = resample(audio, target_sr, sr)
-        return audio
 
 class SpectrogramVisualizer:
-    def __init__(self, figsize=(15,4)):
         self.figsize = figsize
         self.fig_noisy, self.ax_noisy = plt.subplots(figsize=figsize)
         self.fig_enh, self.ax_enh = plt.subplots(figsize=figsize)
-
-    def create_spectrogram(self, audio: Tensor, figure: plt.Figure, ax: plt.Axes,
-                           sr: int, n_fft: int = 1024, hop: int = 512, title: str = None) -> PILImage.Image:
         audio = torch.as_tensor(audio)
         w = torch.hann_window(n_fft, device=audio.device)
         spec = torch.stft(audio, n_fft, hop, window=w, return_complex=False)
         spec = spec.div_(w.pow(2).sum())
         spec = torch.view_as_complex(spec).abs().clamp_min(1e-12).log10().mul(10)
         if spec.dim() > 2:
             spec = spec.squeeze(0)
         ax.clear()
-        t = np.arange(spec.shape[-1]) * hop / sr
-        f = np.arange(spec.shape[0]) * sr // 2 / (n_fft // 2) / 1000
-        ax.pcolormesh(t, f, spec.cpu().numpy(), shading="auto", cmap="inferno", vmin=-100, vmax=0)
-        if title:
-            ax.set_title(title)
         figure.canvas.draw()
-        return PILImage.frombytes("RGB", figure.canvas.get_width_height(), figure.canvas.tostring_rgb())
 
 # ============================================================================
-# Initialization
 # ============================================================================
 
-app_config = AppConfig(device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 model, df, _ = init_df(app_config.model_path, config_allow_defaults=True)
 model = model.to(device=app_config.device).eval()
 audio_processor = AudioProcessor(model, df, app_config)
 audio_loader = AudioLoader()
 visualizer = SpectrogramVisualizer()
 
 NOISES = {
     "None": None,
-    "Kitchen": "samples/dkitchen.wav",
-    "Living Room": "samples/dliving.wav",
-    "River": "samples/nriver.wav",
-    "Cafe": "samples/scafe.wav",
 }
 
 # ============================================================================
 # Main Processing Function
 # ============================================================================
@@ -155,61 +312,484 @@ def process_audio(
     speech_file: Optional[str],
     noise_type: str,
     snr: int,
-    target_rate: int = 22050,
     mic_input: Optional[str] = None,
 ) -> Tuple[str, PILImage.Image, str, PILImage.Image]:
 
-    if mic_input:
-        speech_file = mic_input
-    model_sr = app_config.model_sample_rate
-    audio = audio_loader.load_and_resample(speech_file, model_sr)
 
-    # Add noise if requested
-    noise_fn = NOISES.get(noise_type)
-    if noise_fn:
-        noise_audio = audio_loader.load_and_resample(noise_fn, model_sr)
-        _, _, audio = audio_processor.mix_at_snr(audio, noise_audio, snr)
 
-    enhanced = audio_processor.enhance_audio(audio)
 
-    # Downsample back to target rate if needed
-    if target_rate != model_sr:
-        enhanced = resample(enhanced, target_rate, model_sr)
-        audio = resample(audio, target_rate, model_sr)
 
-    noisy_wav = tempfile.NamedTemporaryFile(suffix="_noisy.wav", delete=False).name
-    enhanced_wav = tempfile.NamedTemporaryFile(suffix="_enhanced.wav", delete=False).name
-    save_audio(noisy_wav, audio, target_rate)
-    save_audio(enhanced_wav, enhanced, target_rate)
 
-    noisy_spec = visualizer.create_spectrogram(audio, visualizer.fig_noisy, visualizer.ax_noisy,
-                                               sr=target_rate, title="Noisy Audio")
-    enhanced_spec = visualizer.create_spectrogram(enhanced, visualizer.fig_enh, visualizer.ax_enh,
-                                                  sr=target_rate, title="Enhanced Audio")
-    return noisy_wav, noisy_spec, enhanced_wav, enhanced_spec
 
 # ============================================================================
 # Gradio Interface
 # ============================================================================
 
-with gr.Blocks() as demo:
-    gr.Markdown("# 🎵 DeepFilterNet2 Denoiser with Resampling Support")
-    audio_file = gr.Audio(type="filepath", label="Upload Audio")
-    mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Record Audio")
-    noise_type = gr.Dropdown(label="Noise Type", choices=list(NOISES.keys()), value="None")
-    snr = gr.Slider(label="SNR (dB)", minimum=-10, maximum=30, step=1, value=10)
-    target_rate = gr.Dropdown(label="Output Sample Rate", choices=[16000, 22050, 44100, 48000], value=22050)
-    process_btn = gr.Button("🚀 Enhance Audio")
-    noisy_audio = gr.Audio(type="filepath")
-    noisy_spec = gr.Image()
-    enhanced_audio = gr.Audio(type="filepath")
-    enhanced_spec = gr.Image()
-
     process_btn.click(
         fn=process_audio,
-        inputs=[audio_file, noise_type, snr, target_rate, mic_input],
-        outputs=[noisy_audio, noisy_spec, enhanced_audio, enhanced_spec]
     )
 
 if __name__ == "__main__":
-    demo.launch()
The updated version of the same regions, as committed:

 class AppConfig:
     """Application configuration"""
     device: torch.device
+    sample_rate: int = 48000
     max_duration_seconds: int = 3600
     cleanup_hours: int = 2
     temp_dir: str = "/tmp"
     model_path: str = "./DeepFilterNet2"
     fade_duration: float = 0.15
+
 
 class AudioProcessor:
+    """Handles audio processing operations"""
+
     def __init__(self, model, df, config: AppConfig):
         self.model = model
         self.df = df
         self.config = config
+
     def mix_at_snr(self, clean: Tensor, noise: Tensor, snr: float, eps: float = 1e-10) -> Tuple[Tensor, Tensor, Tensor]:
+        """Mix clean and noise signal at a given SNR with improved error handling."""
         clean = torch.as_tensor(clean).mean(0, keepdim=True)
         noise = torch.as_tensor(noise).mean(0, keepdim=True)
+
         if noise.shape[1] < clean.shape[1]:
             repeats = int(math.ceil(clean.shape[1] / noise.shape[1]))
             noise = noise.repeat((1, repeats))
+
         max_start = int(noise.shape[1] - clean.shape[1])
         start = torch.randint(0, max_start, ()).item() if max_start > 0 else 0
+        noise = noise[:, start : start + clean.shape[1]]
+
         E_speech = torch.mean(clean.pow(2)) + eps
         E_noise = torch.mean(noise.pow(2)) + eps
         K = torch.sqrt((E_noise / E_speech) * 10 ** (snr / 10) + eps)
         noise = noise / K
         mixture = clean + noise
+
+        assert torch.isfinite(mixture).all(), "Non-finite values detected in mixture"
         max_m = mixture.abs().max()
         if max_m > 1:
+            logger.warning(f"Clipping detected during mixing. Reducing gain by {1/max_m:.3f}")
             clean, noise, mixture = clean / max_m, noise / max_m, mixture / max_m
+
         return clean, noise, mixture
+
     def enhance_audio(self, audio: Tensor) -> Tensor:
+        """Enhance audio using the DeepFilterNet model."""
+        logger.info(f"Enhancing audio with shape {audio.shape}")
         with torch.no_grad():
             enhanced = enhance(self.model, self.df, audio)
+
+        sr = self.config.sample_rate
         fade_samples = int(sr * self.config.fade_duration)
         lim = torch.linspace(0.0, 1.0, fade_samples).unsqueeze(0)
+        lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1)
+        enhanced = enhanced * lim
+
+        return enhanced
+
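The scaling factor in mix_at_snr follows from the SNR definition 10·log10(E_speech/E_noise): dividing the noise by K = sqrt((E_noise/E_speech)·10^(snr/10)) makes the measured SNR of the mixture equal the requested value. A quick standalone sanity check (synthetic tensors, not part of the commit):

```python
# Sanity check of the SNR scaling used in mix_at_snr above.
import torch

def measured_snr_db(clean: torch.Tensor, noise: torch.Tensor) -> float:
    # SNR in dB = 10 * log10(E_speech / E_noise)
    return 10 * torch.log10(clean.pow(2).mean() / noise.pow(2).mean()).item()

clean = torch.randn(1, 48000)        # 1 s of fake "speech" at 48 kHz
noise = torch.randn(1, 48000) * 0.5  # fake noise at an arbitrary level
target_snr = 10.0

K = torch.sqrt((noise.pow(2).mean() / clean.pow(2).mean()) * 10 ** (target_snr / 10))
scaled_noise = noise / K

print(measured_snr_db(clean, scaled_noise))  # ~10.0
```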
 class AudioLoader:
+    """Handles audio loading from various sources"""
+
     @staticmethod
     def ensure_wav(filepath: str) -> str:
+        """Convert audio files to WAV using ffmpeg if needed."""
         if not filepath:
             return filepath
+
+        file_ext = Path(filepath).suffix.lower()
+        if file_ext in ['.mp3', '.m4a', '.ogg', '.flac', '.aac']:
+            wav_path = str(Path(filepath).with_suffix('.wav'))
+            try:
+                subprocess.run(
+                    ["ffmpeg", "-y", "-i", filepath, "-acodec", "pcm_s16le", wav_path],
+                    check=True,
+                    capture_output=True
+                )
+                logger.info(f"Converted {file_ext} to WAV: {wav_path}")
+                return wav_path
+            except subprocess.CalledProcessError as e:
+                logger.error(f"FFmpeg conversion failed: {e.stderr}")
+                raise
         return filepath
+
     @staticmethod
+    def load_audio_gradio(
+        audio_or_file: Union[None, str, Tuple[int, np.ndarray]],
+        sr: int
+    ) -> Optional[Tuple[Tensor, AudioMetaData]]:
+        """Load audio from Gradio input."""
+        if audio_or_file is None:
+            return None
+
         if isinstance(audio_or_file, str):
+            if audio_or_file.lower() == "none":
+                return None
             audio_or_file = AudioLoader.ensure_wav(audio_or_file)
+            audio, meta = load_audio(audio_or_file, sr)
         else:
+            meta = AudioMetaData(-1, -1, -1, -1, "")
+            assert isinstance(audio_or_file, (tuple, list))
+            meta.sample_rate, audio_np = audio_or_file
+
             audio_np = audio_np.reshape(audio_np.shape[0], -1).T
+
             if audio_np.dtype == np.int16:
                 audio_np = (audio_np / (1 << 15)).astype(np.float32)
             elif audio_np.dtype == np.int32:
                 audio_np = (audio_np / (1 << 31)).astype(np.float32)
+
+            audio = resample(torch.from_numpy(audio_np), meta.sample_rate, sr)
+
+        return audio, meta
+
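The numpy branch handles Gradio's raw-microphone path, which delivers integer PCM; dividing by 2^15 (int16) or 2^31 (int32) maps the samples into [-1.0, 1.0). A small illustration with made-up sample values (not part of the commit):

```python
# Integer-PCM-to-float conversion as used by load_audio_gradio above.
import numpy as np

pcm16 = np.array([0, 16384, -32768, 32767], dtype=np.int16)
floats = (pcm16 / (1 << 15)).astype(np.float32)
print(floats)  # approximately [ 0.   0.5  -1.   0.99997 ]
```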
 class SpectrogramVisualizer:
+    """Handles spectrogram visualization"""
+
+    def __init__(self, figsize: Tuple[float, float] = (15.2, 4)):
         self.figsize = figsize
+        plt.style.use('dark_background')
         self.fig_noisy, self.ax_noisy = plt.subplots(figsize=figsize)
+        self.fig_noisy.set_tight_layout(True)
         self.fig_enh, self.ax_enh = plt.subplots(figsize=figsize)
+        self.fig_enh.set_tight_layout(True)
+
+    def specshow(
+        self,
+        spec: Union[Tensor, np.ndarray],
+        ax: Optional[plt.Axes] = None,
+        title: Optional[str] = None,
+        xlabel: Optional[str] = None,
+        ylabel: Optional[str] = None,
+        sr: int = 48000,
+        n_fft: Optional[int] = None,
+        hop: Optional[int] = None,
+        vmin: float = -100,
+        vmax: float = 0,
+        cmap: str = "viridis",
+    ):
+        """Plot a spectrogram of shape [F, T]"""
+        spec_np = spec.cpu().numpy() if isinstance(spec, torch.Tensor) else spec
+
+        if n_fft is None:
+            n_fft = spec.shape[0] * 2 if spec.shape[0] % 2 == 0 else (spec.shape[0] - 1) * 2
+        hop = hop or n_fft // 4
+
+        t = np.arange(0, spec_np.shape[-1]) * hop / sr
+        f = np.arange(0, spec_np.shape[0]) * sr // 2 / (n_fft // 2) / 1000
+
+        im = ax.pcolormesh(
+            t, f, spec_np,
+            rasterized=True,
+            shading="auto",
+            vmin=vmin,
+            vmax=vmax,
+            cmap=cmap
+        )
+
+        if title:
+            ax.set_title(title, fontsize=14, fontweight='bold', pad=15, color='white')
+        if xlabel:
+            ax.set_xlabel(xlabel, fontsize=11, color='white')
+        if ylabel:
+            ax.set_ylabel(ylabel, fontsize=11, color='white')
+
+        ax.grid(True, alpha=0.15, linestyle='--', linewidth=0.5)
+        ax.tick_params(colors='white', labelsize=9)
+
+        return im
+
+    def create_spectrogram(
+        self,
+        audio: Tensor,
+        figure: plt.Figure,
+        ax: plt.Axes,
+        sr: int = 48000,
+        n_fft: int = 1024,
+        hop: int = 512,
+        title: Optional[str] = None,
+    ) -> PILImage.Image:
+        """Create spectrogram image from audio tensor"""
         audio = torch.as_tensor(audio)
+
         w = torch.hann_window(n_fft, device=audio.device)
         spec = torch.stft(audio, n_fft, hop, window=w, return_complex=False)
         spec = spec.div_(w.pow(2).sum())
         spec = torch.view_as_complex(spec).abs().clamp_min(1e-12).log10().mul(10)
+
+        vmax = max(0.0, spec.max().item())
+
         if spec.dim() > 2:
             spec = spec.squeeze(0)
+
         ax.clear()
+        self.specshow(
+            spec,
+            ax=ax,
+            title=title,
+            xlabel="Time [s]",
+            ylabel="Frequency [kHz]",
+            sr=sr,
+            n_fft=n_fft,
+            hop=hop,
+            vmax=vmax,
+        )
+
+        figure.patch.set_facecolor('#0f0f0f')
+        ax.set_facecolor('#0f0f0f')
         figure.canvas.draw()
+
+        return PILImage.frombytes(
+            "RGB",
+            figure.canvas.get_width_height(),
+            figure.canvas.tostring_rgb()
+        )
+
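The STFT-to-decibel pipeline in create_spectrogram can be exercised on its own; a minimal sketch mirroring the defaults above (illustrative, not part of the commit):

```python
# Standalone version of the STFT -> dB magnitude pipeline used above.
import torch

def db_spectrogram(audio: torch.Tensor, n_fft: int = 1024, hop: int = 512) -> torch.Tensor:
    w = torch.hann_window(n_fft)
    spec = torch.stft(audio, n_fft, hop, window=w, return_complex=True)
    spec = spec / w.pow(2).sum()                         # window energy normalization
    return spec.abs().clamp_min(1e-12).log10().mul(10)  # magnitude in dB

audio = torch.randn(1, 48000)  # 1 s of noise at 48 kHz
spec = db_spectrogram(audio)
print(spec.shape)              # torch.Size([1, 513, 94]) -> [channel, F, T]
```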
+class FileManager:
+    """Manages temporary file cleanup"""
+
+    @staticmethod
+    def cleanup_tmp(filter_list: List[str] = None, hours_keep: int = 2, temp_dir: str = "/tmp"):
+        """Clean up old temporary files."""
+        if filter_list is None:
+            filter_list = []
+        filter_list.append("p232")
+
+        if not os.path.exists(temp_dir):
+            return
+
+        logger.info(f"Cleaning up temporary files older than {hours_keep} hours")
+        cleaned = 0
+
+        for filepath in glob.glob(os.path.join(temp_dir, "*")):
+            try:
+                is_old = (time.time() - os.path.getmtime(filepath)) / 3600 > hours_keep
+                filtered = any(filt in filepath for filt in filter_list if filt is not None)
+
+                if is_old and not filtered:
+                    os.remove(filepath)
+                    cleaned += 1
+                    logger.debug(f"Removed file {filepath}")
+            except Exception as e:
+                logger.warning(f"Failed to remove file {filepath}: {e}")
+
+        if cleaned > 0:
+            logger.info(f"Cleaned up {cleaned} temporary files")
+
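The age test cleanup_tmp applies to each file, shown in isolation (hypothetical path, not part of the commit): a file is removed once its mtime is more than hours_keep hours in the past and no keep-filter substring matches its path.

```python
# The per-file staleness check used by cleanup_tmp above, in isolation.
import os, time

def is_stale(filepath: str, hours_keep: float = 2.0) -> bool:
    age_hours = (time.time() - os.path.getmtime(filepath)) / 3600
    return age_hours > hours_keep

# e.g. is_stale("/tmp/abc123_noisy.wav")  # True once the file is >2 h old
```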
 
 # ============================================================================
+# Initialize Application
 # ============================================================================
 
+app_config = AppConfig(
+    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
+)
+
+logger.info(f"Loading DeepFilterNet2 model on {app_config.device}")
 model, df, _ = init_df(app_config.model_path, config_allow_defaults=True)
 model = model.to(device=app_config.device).eval()
+
 audio_processor = AudioProcessor(model, df, app_config)
 audio_loader = AudioLoader()
 visualizer = SpectrogramVisualizer()
+file_manager = FileManager()
 
 NOISES = {
     "None": None,
+    "🍳 Kitchen": "samples/dkitchen.wav",
+    "🛋️ Living Room": "samples/dliving.wav",
+    "🌊 River": "samples/nriver.wav",
+    "☕ Cafe": "samples/scafe.wav",
 }
 
+
 # ============================================================================
 # Main Processing Function
 # ============================================================================
     speech_file: Optional[str],
     noise_type: str,
     snr: int,
     mic_input: Optional[str] = None,
 ) -> Tuple[str, PILImage.Image, str, PILImage.Image]:
+    """Main audio processing pipeline."""
+    try:
+        if mic_input:
+            speech_file = mic_input
+
+        sr = app_config.sample_rate
+        logger.info(f"Processing: file={speech_file}, noise={noise_type}, snr={snr}")
+
+        if speech_file is not None:
+            speech_file = audio_loader.ensure_wav(speech_file)
+            sample, meta = load_audio(speech_file, sr)
+
+            max_len = app_config.max_duration_seconds * sr
+            if sample.shape[-1] > max_len:
+                logger.warning(f"Audio too long, truncating to {app_config.max_duration_seconds}s")
+                start = torch.randint(0, sample.shape[-1] - max_len, ()).item()
+                sample = sample[..., start : start + max_len]
+        else:
+            sample, meta = load_audio("samples/p232_013_clean.wav", sr)
+            sample = sample[..., : app_config.max_duration_seconds * sr]
+
+        if sample.dim() > 1 and sample.shape[0] > 1:
+            logger.info(f"Converting from {sample.shape[0]} channels to mono")
+            sample = sample.mean(dim=0, keepdim=True)
+
+        logger.info(f"Loaded audio with shape {sample.shape}")
+
+        noise_fn = NOISES.get(noise_type)
+        if noise_fn is not None:
+            noise, _ = load_audio(noise_fn, sr)
+            logger.info(f"Adding {noise_type} noise at {snr} dB SNR")
+            _, _, sample = audio_processor.mix_at_snr(sample, noise, int(snr))
+
+        enhanced = audio_processor.enhance_audio(sample)
+        logger.info("Audio enhancement completed")
+
+        if meta.sample_rate != sr and meta.sample_rate > 0:
+            enhanced = resample(enhanced, sr, meta.sample_rate)
+            sample = resample(sample, sr, meta.sample_rate)
+            sr = meta.sample_rate
+
+        noisy_wav = tempfile.NamedTemporaryFile(suffix="_noisy.wav", delete=False).name
+        save_audio(noisy_wav, sample, sr)
+
+        enhanced_wav = tempfile.NamedTemporaryFile(suffix="_enhanced.wav", delete=False).name
+        save_audio(enhanced_wav, enhanced, sr)
+
+        logger.info(f"Saved outputs: {noisy_wav}, {enhanced_wav}")
+
+        noisy_spec = visualizer.create_spectrogram(
+            sample,
+            visualizer.fig_noisy,
+            visualizer.ax_noisy,
+            sr=sr,
+            title="Input Audio Spectrogram"
+        )
+
+        enhanced_spec = visualizer.create_spectrogram(
+            enhanced,
+            visualizer.fig_enh,
+            visualizer.ax_enh,
+            sr=sr,
+            title="Enhanced Audio Spectrogram"
+        )
+
+        filter_files = [speech_file, noisy_wav, enhanced_wav]
+        if mic_input:
+            filter_files.append(mic_input)
+        file_manager.cleanup_tmp(filter_files, app_config.cleanup_hours)
+
+        return noisy_wav, noisy_spec, enhanced_wav, enhanced_spec
+
+    except Exception as e:
+        logger.error(f"Error processing audio: {e}", exc_info=True)
+        raise gr.Error(f"Processing failed: {str(e)}")
+
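Once the module is loaded (model and helpers initialized above), the pipeline can be exercised directly, bypassing the UI. A sketch, assuming the bundled sample and noise files exist:

```python
# Direct call into process_audio as defined above (illustrative only).
noisy_wav, noisy_img, enhanced_wav, enhanced_img = process_audio(
    speech_file="samples/p232_013_clean.wav",
    noise_type="🍳 Kitchen",  # key into NOISES
    snr=10,
)
print(noisy_wav, enhanced_wav)  # paths to the saved WAV files in /tmp
```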
+def toggle_input_mode(choice: str):
+    """Toggle between microphone and file upload."""
+    if choice == "mic":
+        return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
+    else:
+        return gr.update(visible=False, value=None), gr.update(visible=True, value=None)
+
+# ============================================================================
+# Custom CSS
+# ============================================================================
+
+custom_css = """
+/* Global Styles */
+.gradio-container {
+    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif !important;
+}
+
+/* Hero Section */
+#hero-section {
+    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+    padding: 50px 30px;
+    border-radius: 20px;
+    margin-bottom: 40px;
+    box-shadow: 0 15px 40px rgba(102, 126, 234, 0.4);
+    text-align: center;
+}
+
+#hero-section h1 {
+    color: white;
+    font-size: 3.2em;
+    font-weight: 800;
+    margin: 0 0 15px 0;
+    text-shadow: 2px 2px 8px rgba(0,0,0,0.2);
+    letter-spacing: -1px;
+}
+
+#hero-section p {
+    color: rgba(255,255,255,0.95);
+    font-size: 1.25em;
+    margin: 10px auto;
+    max-width: 800px;
+    line-height: 1.6;
+    font-weight: 300;
+}
+
+/* Feature Cards */
+.feature-card {
+    background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
+    padding: 25px;
+    border-radius: 15px;
+    box-shadow: 0 4px 15px rgba(0,0,0,0.08);
+    margin-bottom: 20px;
+    border: 1px solid rgba(255,255,255,0.5);
+    transition: all 0.3s ease;
+}
+
+.feature-card:hover {
+    transform: translateY(-3px);
+    box-shadow: 0 8px 25px rgba(0,0,0,0.12);
+}
+
+/* Input Controls Section */
+.input-controls {
+    background: linear-gradient(135deg, #a8edea 0%, #fed6e3 100%);
+    padding: 30px;
+    border-radius: 15px;
+    box-shadow: 0 5px 20px rgba(0,0,0,0.1);
+}
+
+/* Output Section */
+.output-section {
+    background: linear-gradient(135deg, #ffecd2 0%, #fcb69f 100%);
+    padding: 30px;
+    border-radius: 15px;
+    box-shadow: 0 5px 20px rgba(0,0,0,0.1);
+}
+
+/* Section Headers */
+.section-header {
+    color: #667eea;
+    font-size: 1.8em;
+    font-weight: 700;
+    margin: 30px 0 20px 0;
+    text-align: center;
+    background: linear-gradient(135deg, #667eea, #764ba2);
+    -webkit-background-clip: text;
+    -webkit-text-fill-color: transparent;
+    background-clip: text;
+}
+
+/* Process Button */
+.process-button {
+    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
+    border: none !important;
+    font-size: 1.4em !important;
+    font-weight: 700 !important;
+    padding: 20px 50px !important;
+    border-radius: 50px !important;
+    box-shadow: 0 10px 30px rgba(102, 126, 234, 0.5) !important;
+    transition: all 0.3s ease !important;
+    color: white !important;
+    text-transform: uppercase;
+    letter-spacing: 1px;
+}
+
+.process-button:hover {
+    transform: translateY(-3px) scale(1.02) !important;
+    box-shadow: 0 15px 40px rgba(102, 126, 234, 0.7) !important;
+}
+
+/* Audio Components */
+.audio-wrapper {
+    background: white;
+    padding: 20px;
+    border-radius: 12px;
+    box-shadow: 0 3px 12px rgba(0,0,0,0.08);
+    margin: 15px 0;
+}
+
+/* Tabs */
+.tab-nav button {
+    font-weight: 600 !important;
+    font-size: 1.1em !important;
+    padding: 12px 24px !important;
+    border-radius: 10px 10px 0 0 !important;
+}
+
+.tab-nav button[aria-selected="true"] {
+    background: linear-gradient(135deg, #667eea, #764ba2) !important;
+    color: white !important;
+}
+
+/* Info Box */
+.info-box {
+    background: linear-gradient(135deg, #e0c3fc 0%, #8ec5fc 100%);
+    padding: 25px;
+    border-radius: 15px;
+    margin: 25px 0;
+    border-left: 5px solid #667eea;
+    box-shadow: 0 4px 15px rgba(0,0,0,0.1);
+}
+
+.info-box h3 {
+    color: #667eea;
+    font-size: 1.4em;
+    font-weight: 700;
+    margin-top: 0;
+}
+
+.info-box ul {
+    margin: 10px 0;
+    padding-left: 25px;
+}
+
+.info-box li {
+    margin: 8px 0;
+    line-height: 1.6;
+}
+
+/* Examples Section */
+.examples-section {
+    background: linear-gradient(135deg, #ffeaa7 0%, #dfe6e9 100%);
+    padding: 25px;
+    border-radius: 15px;
+    margin-top: 30px;
+    box-shadow: 0 4px 15px rgba(0,0,0,0.08);
+}
+
+/* Footer */
+#footer {
+    text-align: center;
+    padding: 30px 20px;
+    margin-top: 50px;
+    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+    border-radius: 15px;
+    color: white;
+}
+
+#footer h3 {
+    margin: 0 0 10px 0;
+    font-size: 1.5em;
+    font-weight: 700;
+}
+
+#footer p {
+    margin: 5px 0;
+    opacity: 0.9;
+}
+
+/* Radio Buttons */
+.radio-group label {
+    padding: 12px 20px !important;
+    border-radius: 10px !important;
+    font-weight: 600 !important;
+    transition: all 0.3s ease !important;
+}
+
+/* Dropdowns */
+.dropdown select {
+    border-radius: 10px !important;
+    padding: 12px !important;
+    font-size: 1.05em !important;
+    border: 2px solid #e0e0e0 !important;
+    transition: all 0.3s ease !important;
+}
+
+.dropdown select:focus {
+    border-color: #667eea !important;
+    box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important;
+}
+"""
+
 # ============================================================================
 # Gradio Interface
 # ============================================================================
 
+with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="indigo")) as demo:
+
+    # Hero Section
+    gr.HTML("""
+    <div id="hero-section">
+        <h1>🎵 DeepFilterNet2 Audio Enhancement</h1>
+        <p>Transform noisy audio into crystal-clear sound using cutting-edge AI technology</p>
+        <p style="font-size: 0.95em; margin-top: 15px;">
+            ✨ Real-time Processing | 🎯 State-of-the-Art Quality | 🚀 Lightning Fast
+        </p>
+    </div>
+    """)
+
+    # Quick Start Guide
+    with gr.Row():
+        gr.Markdown("""
+        <div class="info-box">
+        <h3>🚀 Quick Start Guide</h3>
+        <ul>
+            <li><strong>Step 1:</strong> Upload an audio file or record using your microphone</li>
+            <li><strong>Step 2:</strong> Optionally add synthetic noise to test the denoiser</li>
+            <li><strong>Step 3:</strong> Adjust SNR settings if needed</li>
+            <li><strong>Step 4:</strong> Click the "Denoise Audio" button</li>
+            <li><strong>Step 5:</strong> Compare results with interactive spectrograms</li>
+        </ul>
+        </div>
+        """)
+
+    # Main Interface
+    with gr.Row():
+        # Left Column - Input Controls
+        with gr.Column(scale=1):
+            gr.HTML('<h2 class="section-header">📤 Audio Input</h2>')
+
+            with gr.Group(elem_classes="input-controls"):
+                input_mode = gr.Radio(
+                    ["file", "mic"],
+                    value="file",
+                    label="🎙️ Input Method",
+                    info="Choose your preferred input source",
+                    elem_classes="radio-group"
+                )
+
+                audio_file = gr.Audio(
+                    type="filepath",
+                    label="📁 Upload Audio File",
+                    visible=True,
+                    elem_classes="audio-wrapper"
+                )
+
+                mic_input = gr.Audio(
+                    sources=["microphone"],
+                    type="filepath",
+                    label="🎤 Record Audio",
+                    visible=False,
+                    elem_classes="audio-wrapper"
+                )
+
+            gr.HTML('<h2 class="section-header">⚙️ Enhancement Settings</h2>')
+
+            with gr.Group(elem_classes="feature-card"):
+                noise_type = gr.Dropdown(
+                    label="🔊 Background Noise Type",
+                    choices=list(NOISES.keys()),
+                    value="None",
+                    info="Add synthetic noise for testing",
+                    elem_classes="dropdown"
+                )
+
+                snr = gr.Dropdown(
+                    label="📊 Signal-to-Noise Ratio (dB)",
+                    choices=["-5", "0", "10", "20"],
+                    value="10",
+                    info="Higher = cleaner signal",
+                    elem_classes="dropdown"
+                )
+
+            process_btn = gr.Button(
+                "🚀 Denoise Audio",
+                elem_classes="process-button",
+                size="lg"
+            )
+
+        # Right Column - Results
+        with gr.Column(scale=2):
+            gr.HTML('<h2 class="section-header">📊 Results & Comparison</h2>')
+
+            with gr.Tabs():
+                with gr.Tab("🔴 Input Audio", elem_classes="output-section"):
+                    noisy_audio = gr.Audio(
+                        type="filepath",
+                        label="Original/Noisy Audio",
+                        elem_classes="audio-wrapper"
+                    )
+                    noisy_spec = gr.Image(
+                        label="Input Spectrogram",
+                        elem_classes="audio-wrapper"
+                    )
+
+                with gr.Tab("🟢 Enhanced Audio", elem_classes="output-section"):
+                    enhanced_audio = gr.Audio(
+                        type="filepath",
+                        label="Enhanced Audio",
+                        elem_classes="audio-wrapper"
+                    )
+                    enhanced_spec = gr.Image(
+                        label="Enhanced Spectrogram",
+                        elem_classes="audio-wrapper"
+                    )
+
+    # Examples Section
+    gr.HTML('<h2 class="section-header">🎯 Try These Examples</h2>')
+
+    with gr.Group(elem_classes="examples-section"):
+        gr.Examples(
+            examples=[
+                ["./samples/p232_013_clean.wav", "🍳 Kitchen", "10"],
+                ["./samples/p232_013_clean.wav", "☕ Cafe", "10"],
+                ["./samples/p232_019_clean.wav", "☕ Cafe", "10"],
+                ["./samples/p232_019_clean.wav", "🌊 River", "10"],
+            ],
+            inputs=[audio_file, noise_type, snr],
+            outputs=[noisy_audio, noisy_spec, enhanced_audio, enhanced_spec],
+            fn=process_audio,
+            cache_examples=True,
+            label="Click any example to try it instantly",
+        )
+
+    # Technical Information
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("""
+            <div class="info-box">
+            <h3>💡 How It Works</h3>
+            <p><strong>DeepFilterNet2</strong> uses advanced deep learning to identify and remove unwanted background noise while preserving speech clarity. The model analyzes spectral patterns to distinguish between signal and noise components.</p>
+            </div>
+            """)
+
+        with gr.Column():
+            gr.Markdown("""
+            <div class="info-box">
+            <h3>📋 Technical Specifications</h3>
+            <ul>
+                <li><strong>Model:</strong> DeepFilterNet2 (State-of-the-art)</li>
+                <li><strong>Sample Rate:</strong> 48 kHz</li>
+                <li><strong>Max Duration:</strong> 1 hour</li>
+                <li><strong>Formats:</strong> WAV, MP3, M4A, OGG, FLAC, AAC</li>
+                <li><strong>Processing:</strong> Real-time capable</li>
+            </ul>
+            </div>
+            """)
+
+    # Footer
+    gr.HTML("""
+    <div id="footer">
+        <h3>🎵 Powered by DeepFilterNet2</h3>
+        <p>Advanced AI-driven audio enhancement technology</p>
+        <p><em>Built with Gradio • Optimized for Performance</em></p>
+    </div>
+    """)
+
+    # Event Handlers
     process_btn.click(
         fn=process_audio,
+        inputs=[audio_file, noise_type, snr, mic_input],
+        outputs=[noisy_audio, noisy_spec, enhanced_audio, enhanced_spec],
+        api_name="denoise",
     )
+
+    input_mode.change(
+        fn=toggle_input_mode,
+        inputs=input_mode,
+        outputs=[mic_input, audio_file],
+    )
+
+# Initial cleanup
+file_manager.cleanup_tmp()
 
+# Launch application
 if __name__ == "__main__":
+    demo.queue().launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False,
+    )
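The Radio-driven visibility toggle wired up above is a general Gradio pattern; a self-contained miniature with hypothetical component names, independent of this app:

```python
# Miniature of the input-mode toggle pattern used in the interface above.
import gradio as gr

def toggle(choice: str):
    # Show exactly one of the two inputs, clearing the hidden one.
    return (gr.update(visible=choice == "mic", value=None),
            gr.update(visible=choice == "file", value=None))

with gr.Blocks() as mini:
    mode = gr.Radio(["file", "mic"], value="file", label="Input Method")
    mic = gr.Audio(sources=["microphone"], type="filepath", visible=False)
    upload = gr.Audio(type="filepath", visible=True)
    mode.change(fn=toggle, inputs=mode, outputs=[mic, upload])

# mini.launch()
```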