danielr-ceva commited on
Commit
30f1424
·
verified ·
1 Parent(s): d0fa80e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +541 -0
app.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import math
3
+ import tempfile
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Dict, Optional, Tuple
7
+
8
+ import gradio as gr
9
+ import librosa
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import onnxruntime as ort
13
+ import soundfile as sf
14
+ from PIL import Image
15
+
16
+ # -----------------------------
17
+ # Configuration
18
+ # -----------------------------
19
+ MAX_SECONDS = 10.0
20
+ ONNX_DIR = Path("./hf_space/onnx")
21
+
22
+
23
+ @dataclass(frozen=True)
24
+ class ModelSpec:
25
+ name: str
26
+ sr: int
27
+ onnx_path: str
28
+
29
+
30
+ # -----------------------------
31
+ # Model discovery and metadata
32
+ # -----------------------------
33
+ def _infer_model_meta(model_name: str) -> int:
34
+ normalized = model_name.lower().replace("-", "_")
35
+
36
+ if "48khz" in normalized or "48k" in normalized or "48hr" in normalized:
37
+ return 48000
38
+
39
+ # Fallback for unknown 16 kHz DPDFNet variants
40
+ return 16000
41
+
42
+
43
+ def _display_label(spec: ModelSpec) -> str:
44
+ khz = int(spec.sr // 1000)
45
+ return f"{spec.name} ({khz} kHz)"
46
+
47
+
48
+ def discover_model_presets() -> Dict[str, ModelSpec]:
49
+ ordered_names = [
50
+ "baseline",
51
+ "dpdfnet2",
52
+ "dpdfnet4",
53
+ "dpdfnet8",
54
+ "dpdfnet2_48khz_hr",
55
+ ]
56
+
57
+ found_paths = {p.stem: p for p in ONNX_DIR.glob("*.onnx") if p.is_file()}
58
+ presets: Dict[str, ModelSpec] = {}
59
+
60
+ for name in ordered_names:
61
+ p = found_paths.get(name)
62
+ if p is None:
63
+ continue
64
+ sr = _infer_model_meta(name)
65
+ spec = ModelSpec(
66
+ name=name,
67
+ sr=sr,
68
+ onnx_path=str(p),
69
+ )
70
+ presets[_display_label(spec)] = spec
71
+
72
+ # Include any additional ONNX files not in the canonical order list.
73
+ for name, p in sorted(found_paths.items()):
74
+ if name in ordered_names:
75
+ continue
76
+ sr = _infer_model_meta(name)
77
+ spec = ModelSpec(
78
+ name=name,
79
+ sr=sr,
80
+ onnx_path=str(p),
81
+ )
82
+ presets[_display_label(spec)] = spec
83
+
84
+ return presets
85
+
86
+
87
+ MODEL_PRESETS = discover_model_presets()
88
+ DEFAULT_MODEL_KEY = next(iter(MODEL_PRESETS), None)
89
+
90
+
91
+ # -----------------------------
92
+ # ONNX Runtime + frontend cache
93
+ # -----------------------------
94
+ _SESSIONS: Dict[str, ort.InferenceSession] = {}
95
+ _INIT_STATES: Dict[str, np.ndarray] = {}
96
+
97
+
98
+ def resolve_model_path(local_path: str) -> str:
99
+ p = Path(local_path)
100
+ if p.exists():
101
+ return str(p)
102
+ raise gr.Error(
103
+ f"ONNX model not found at: {local_path}. "
104
+ "Expected local models under ./onnx/."
105
+ )
106
+
107
+
108
+ def get_ort_session(model_key: str) -> ort.InferenceSession:
109
+ if model_key in _SESSIONS:
110
+ return _SESSIONS[model_key]
111
+
112
+ spec = MODEL_PRESETS[model_key]
113
+ onnx_path = resolve_model_path(spec.onnx_path)
114
+
115
+ sess = ort.InferenceSession(
116
+ onnx_path,
117
+ providers=["CPUExecutionProvider"],
118
+ )
119
+ _SESSIONS[model_key] = sess
120
+ return sess
121
+
122
+
123
+ def _resolve_state_path(model_key: str) -> Path:
124
+ spec = MODEL_PRESETS[model_key]
125
+ model_path = Path(spec.onnx_path)
126
+ state_path = model_path.with_name(f"{model_path.stem}_state.npz")
127
+ if not state_path.is_file():
128
+ raise gr.Error(f"State file not found: {state_path}")
129
+ return state_path
130
+
131
+
132
+ def _load_initial_state(model_key: str, session: ort.InferenceSession) -> np.ndarray:
133
+ if model_key in _INIT_STATES:
134
+ return _INIT_STATES[model_key]
135
+
136
+ state_path = _resolve_state_path(model_key)
137
+ with np.load(state_path) as data:
138
+ if "init_state" not in data:
139
+ raise gr.Error(f"Missing 'init_state' key in state file: {state_path}")
140
+ init_state = np.ascontiguousarray(data["init_state"].astype(np.float32, copy=False))
141
+
142
+ expected_shape = session.get_inputs()[1].shape
143
+ if len(expected_shape) != init_state.ndim:
144
+ raise gr.Error(
145
+ f"Initial state rank mismatch for {state_path.name}: expected={expected_shape}, got={tuple(init_state.shape)}"
146
+ )
147
+ for exp_dim, act_dim in zip(expected_shape, init_state.shape):
148
+ if isinstance(exp_dim, int) and exp_dim != act_dim:
149
+ raise gr.Error(
150
+ f"Initial state shape mismatch for {state_path.name}: expected={expected_shape}, got={tuple(init_state.shape)}"
151
+ )
152
+
153
+ _INIT_STATES[model_key] = init_state
154
+ return init_state
155
+
156
+
157
+ # -----------------------------
158
+ # STFT/iSTFT (module-free)
159
+ # -----------------------------
160
+ def vorbis_window(window_len: int) -> np.ndarray:
161
+ window_size_h = window_len / 2
162
+ indices = np.arange(window_len)
163
+ sin = np.sin(0.5 * np.pi * (indices + 0.5) / window_size_h)
164
+ window = np.sin(0.5 * np.pi * sin * sin)
165
+ return window.astype(np.float32)
166
+
167
+
168
+ def get_wnorm(window_len: int, frame_size: int) -> float:
169
+ return 1.0 / (window_len ** 2 / (2 * frame_size))
170
+
171
+
172
+ def _infer_stft_params(model_key: str, session: ort.InferenceSession) -> Tuple[int, int, float, np.ndarray]:
173
+ # ONNX spec input is [B, T, F, 2] (or dynamic variants).
174
+ spec_shape = session.get_inputs()[0].shape
175
+ freq_bins = spec_shape[-2] if len(spec_shape) >= 2 else None
176
+
177
+ if isinstance(freq_bins, int) and freq_bins > 1:
178
+ win_len = int((freq_bins - 1) * 2)
179
+ else:
180
+ # 20 ms windows for DPDFNet family.
181
+ sr = MODEL_PRESETS[model_key].sr
182
+ win_len = int(round(sr * 0.02))
183
+
184
+ hop = win_len // 2
185
+ win = vorbis_window(win_len)
186
+ wnorm = get_wnorm(win_len, hop)
187
+ return win_len, hop, wnorm, win
188
+
189
+
190
+ def _preprocess_waveform(waveform: np.ndarray, win_len: int, hop: int, wnorm: float, win: np.ndarray) -> np.ndarray:
191
+ audio = np.asarray(waveform, dtype=np.float32).reshape(-1)
192
+ audio_pad = np.pad(audio, (0, win_len), mode="constant")
193
+
194
+ spec = librosa.stft(
195
+ y=audio_pad,
196
+ n_fft=win_len,
197
+ hop_length=hop,
198
+ win_length=win_len,
199
+ window=win,
200
+ center=True,
201
+ pad_mode="reflect",
202
+ )
203
+ spec = (spec.T * wnorm).astype(np.complex64, copy=False) # [T, F]
204
+ spec_ri = np.stack([spec.real, spec.imag], axis=-1).astype(np.float32, copy=False) # [T, F, 2]
205
+ return spec_ri[None, ...] # [1, T, F, 2]
206
+
207
+
208
+ def _postprocess_spec(spec_e: np.ndarray, win_len: int, hop: int, wnorm: float, win: np.ndarray) -> np.ndarray:
209
+ spec_c = np.asarray(spec_e[0], dtype=np.float32) # [T, F, 2]
210
+ spec = (spec_c[..., 0] + 1j * spec_c[..., 1]).T.astype(np.complex64, copy=False) # [F, T]
211
+
212
+ waveform_e = librosa.istft(
213
+ spec,
214
+ hop_length=hop,
215
+ win_length=win_len,
216
+ window=win,
217
+ center=True,
218
+ length=None,
219
+ ).astype(np.float32, copy=False)
220
+
221
+ waveform_e = waveform_e / wnorm
222
+ waveform_e = np.concatenate(
223
+ [waveform_e[win_len * 2 :], np.zeros(win_len * 2, dtype=np.float32)],
224
+ axis=0,
225
+ )
226
+ return waveform_e
227
+
228
+
229
+ # -----------------------------
230
+ # ONNX inference (non-streaming pre/post, streaming ONNX state loop)
231
+ # -----------------------------
232
+ def enhance_audio_onnx(
233
+ audio_mono: np.ndarray,
234
+ model_key: str,
235
+ ) -> np.ndarray:
236
+ sess = get_ort_session(model_key)
237
+
238
+ inputs = sess.get_inputs()
239
+ outputs = sess.get_outputs()
240
+ if len(inputs) < 2 or len(outputs) < 2:
241
+ raise gr.Error(
242
+ "Expected streaming ONNX signature with 2 inputs (spec, state) and 2 outputs (spec_e, state_out)."
243
+ )
244
+
245
+ in_spec_name = inputs[0].name
246
+ in_state_name = inputs[1].name
247
+ out_spec_name = outputs[0].name
248
+ out_state_name = outputs[1].name
249
+
250
+ waveform = np.asarray(audio_mono, dtype=np.float32).reshape(-1)
251
+ win_len, hop, wnorm, win = _infer_stft_params(model_key, sess)
252
+ spec_r_np = _preprocess_waveform(waveform, win_len=win_len, hop=hop, wnorm=wnorm, win=win)
253
+
254
+ state = _load_initial_state(model_key, sess).copy()
255
+ spec_e_frames = []
256
+ num_frames = int(spec_r_np.shape[1])
257
+
258
+ for t in range(num_frames):
259
+ spec_t = np.ascontiguousarray(spec_r_np[:, t : t + 1, :, :], dtype=np.float32)
260
+ spec_e_t, state = sess.run(
261
+ [out_spec_name, out_state_name],
262
+ {in_spec_name: spec_t, in_state_name: state},
263
+ )
264
+ spec_e_frames.append(np.ascontiguousarray(spec_e_t, dtype=np.float32))
265
+
266
+ if not spec_e_frames:
267
+ return waveform
268
+
269
+ spec_e_np = np.concatenate(spec_e_frames, axis=1)
270
+ waveform_e = _postprocess_spec(spec_e_np, win_len=win_len, hop=hop, wnorm=wnorm, win=win)
271
+ return np.asarray(waveform_e, dtype=np.float32).reshape(-1)
272
+
273
+
274
+ # -----------------------------
275
+ # Audio utilities
276
+ # -----------------------------
277
+ def _load_wav_from_gradio_path(path: str) -> Tuple[np.ndarray, int]:
278
+ data, sr = sf.read(path, always_2d=True)
279
+ data = data.astype(np.float32, copy=False)
280
+ return data, int(sr)
281
+
282
+
283
+ def _to_mono(x: np.ndarray) -> Tuple[np.ndarray, int]:
284
+ if x.ndim == 1:
285
+ return x.astype(np.float32, copy=False), 1
286
+ if x.shape[1] == 1:
287
+ return x[:, 0], 1
288
+ return x.mean(axis=1), int(x.shape[1])
289
+
290
+
291
+ def _resample(y: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
292
+ if sr_in == sr_out:
293
+ return y
294
+ return librosa.resample(y, orig_sr=sr_in, target_sr=sr_out).astype(np.float32, copy=False)
295
+
296
+
297
+ def _match_length(y: np.ndarray, target_len: int) -> np.ndarray:
298
+ if len(y) == target_len:
299
+ return y
300
+ if len(y) > target_len:
301
+ return y[:target_len]
302
+ out = np.zeros((target_len,), dtype=y.dtype)
303
+ out[: len(y)] = y
304
+ return out
305
+
306
+
307
+ def _save_wav(y: np.ndarray, sr: int, prefix: str) -> str:
308
+ tmp = tempfile.NamedTemporaryFile(prefix=prefix, suffix=".wav", delete=False)
309
+ tmp.close()
310
+ sf.write(tmp.name, y, sr)
311
+ return tmp.name
312
+
313
+
314
+ def _spectrogram_image(y: np.ndarray, sr: int) -> Image.Image:
315
+ win_length = max(256, int(0.032 * sr))
316
+ hop_length = max(64, int(0.008 * sr))
317
+ n_fft = 1 << (int(math.ceil(math.log2(win_length))))
318
+
319
+ S = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False)
320
+ S_db = librosa.amplitude_to_db(np.abs(S) + 1e-10, ref=np.max)
321
+
322
+ fig, ax = plt.subplots(figsize=(8.4, 3.2))
323
+ ax.imshow(S_db, origin="lower", aspect="auto")
324
+ ax.set_axis_off()
325
+ fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
326
+
327
+ buf = io.BytesIO()
328
+ fig.savefig(buf, format="png", dpi=160)
329
+ plt.close(fig)
330
+ buf.seek(0)
331
+ return Image.open(buf)
332
+
333
+
334
+ # -----------------------------
335
+ # Main pipeline
336
+ # -----------------------------
337
+ def run_enhancement(
338
+ source: str,
339
+ mic_path: Optional[str],
340
+ file_path: Optional[str],
341
+ model_key: str,
342
+ ):
343
+ if not MODEL_PRESETS:
344
+ raise gr.Error("No ONNX models found under ./onnx/. Add models and retry.")
345
+
346
+ chosen_path = mic_path if source == "Microphone" else file_path
347
+ if not chosen_path:
348
+ raise gr.Error("Please provide audio either from the microphone or by uploading a file.")
349
+
350
+ x, sr_orig = _load_wav_from_gradio_path(chosen_path)
351
+ y_mono, n_ch = _to_mono(x)
352
+
353
+ max_samples = int(MAX_SECONDS * sr_orig)
354
+ was_trimmed = len(y_mono) > max_samples
355
+ if was_trimmed:
356
+ y_mono = y_mono[:max_samples]
357
+ dur = len(y_mono) / float(sr_orig)
358
+
359
+ spec = MODEL_PRESETS[model_key]
360
+ sr_model = spec.sr
361
+
362
+ y_model = _resample(y_mono, sr_orig, sr_model)
363
+ y_enh_model = enhance_audio_onnx(y_model, model_key)
364
+
365
+ y_enh = _resample(y_enh_model, sr_model, sr_orig)
366
+ y_enh = _match_length(y_enh, len(y_mono))
367
+
368
+ noisy_out = _save_wav(y_mono, sr_orig, prefix="noisy_mono_")
369
+ enh_out = _save_wav(y_enh, sr_orig, prefix="enhanced_")
370
+
371
+ noisy_img = _spectrogram_image(y_mono, sr_orig)
372
+ enh_img = _spectrogram_image(y_enh, sr_orig)
373
+
374
+ status = (
375
+ f"**Input:** {sr_orig} Hz, {dur:.2f}s, channels={n_ch} ⭢ mono\n\n"
376
+ f"**Model:** {spec.name} (runs at {sr_model} Hz)\n\n"
377
+ + (
378
+ f"**Resampling:** {sr_orig} ⭢ {sr_model} ⭢ {sr_orig}\n\n"
379
+ if sr_orig != sr_model
380
+ else "**Resampling:** none\n\n"
381
+ )
382
+ + (f"**Trimmed:** first {MAX_SECONDS:.0f}s used\n" if was_trimmed else "")
383
+ + "\n✅ Done."
384
+ )
385
+ return noisy_out, enh_out, noisy_img, enh_img, status
386
+
387
+
388
+ def set_source_visibility(source: str):
389
+ return (
390
+ gr.update(visible=(source == "Microphone")),
391
+ gr.update(visible=(source == "Upload")),
392
+ )
393
+
394
+
395
+ # -----------------------------
396
+ # UI (light polish)
397
+ # -----------------------------
398
+ THEME = gr.themes.Soft(
399
+ primary_hue="orange",
400
+ neutral_hue="slate",
401
+ font=[
402
+ "Arial",
403
+ "ui-sans-serif",
404
+ "system-ui",
405
+ "Segoe UI",
406
+ "Roboto",
407
+ "Helvetica Neue",
408
+ "Noto Sans",
409
+ "Liberation Sans",
410
+ "sans-serif",
411
+ ],
412
+ )
413
+
414
+ CSS = """
415
+ .gradio-container{
416
+ max-width: 1040px !important;
417
+ margin: 0 auto !important;
418
+ font-family: Arial, ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica Neue, Noto Sans, Liberation Sans, sans-serif !important;
419
+ }
420
+
421
+ #header {
422
+ padding: 14px 16px;
423
+ border-radius: 16px;
424
+ border: 1px solid rgba(0,0,0,0.08);
425
+ background: linear-gradient(135deg, rgba(255,152,0,0.14), rgba(255,152,0,0.04));
426
+ }
427
+ #header h1{
428
+ margin: 0;
429
+ font-size: 24px;
430
+ font-weight: 800;
431
+ letter-spacing: -0.2px;
432
+ }
433
+ #header p{
434
+ margin: 6px 0 0 0;
435
+ color: var(--body-text-color-subdued);
436
+ font-size: 13.5px;
437
+ line-height: 1.35;
438
+ }
439
+
440
+ .spec img { border-radius: 14px; }
441
+ .audio { border-radius: 14px !important; overflow: hidden; }
442
+
443
+ #run_btn{
444
+ border-radius: 12px !important;
445
+ font-weight: 800 !important;
446
+ }
447
+
448
+ #status_md p{ margin: 0.35rem 0; }
449
+ """
450
+
451
+ with gr.Blocks(theme=THEME, css=CSS, title="DPDFNet Speech Enhancement") as demo:
452
+ gr.HTML(
453
+ # """
454
+ # <div id="header">
455
+ # <h1>DPDFNet Speech Enhancement</h1>
456
+ # <p>
457
+ # Upload or record up to 10 seconds. Multi-channel inputs are averaged to mono.
458
+ # Choose any local ONNX model from <code>./onnx</code>.
459
+ # Pre/postprocessing uses the same non-streaming STFT/iSTFT flow as <code>streaming/infer_dpdfnet_onnx.py</code>.
460
+ # </p>
461
+ # </div>
462
+ # """
463
+ """
464
+ <div id="header" style="text-align: center; margin-bottom: 25px;">
465
+
466
+ <h1 style="margin-bottom: 6px;">DPDFNet Speech Enhancement</h1>
467
+
468
+ <p style="font-size: 14px; letter-spacing: 1px; margin-bottom: 14px; color: #555;">
469
+ Causal • Real-Time • Edge-Ready
470
+ </p>
471
+
472
+ <p style="max-width: 720px; margin: 0 auto; font-size: 15px; line-height: 1.6;">
473
+ DPDFNet extends DeepFilterNet2 with Dual-Path RNN blocks to improve
474
+ long-range temporal and cross-band modeling while preserving low latency.
475
+ Designed for single-channel streaming speech enhancement under challenging noise conditions.
476
+ </p>
477
+
478
+ <hr style="margin-top: 22px; border: none; height: 1px; background: linear-gradient(to right, transparent, #ddd, transparent);">
479
+
480
+ </div>
481
+ """
482
+ )
483
+
484
+ with gr.Row():
485
+ model_key = gr.Dropdown(
486
+ choices=list(MODEL_PRESETS.keys()),
487
+ value=DEFAULT_MODEL_KEY,
488
+ label="Model",
489
+ # info="Audio is resampled to model SR, enhanced with ONNX, then resampled back.",
490
+ interactive=True,
491
+ )
492
+
493
+ source = gr.Radio(
494
+ choices=["Microphone", "Upload"],
495
+ value="Upload",
496
+ label="Input source",
497
+ )
498
+
499
+ with gr.Row():
500
+ mic_audio = gr.Audio(
501
+ sources=["microphone"],
502
+ type="filepath",
503
+ format="wav",
504
+ label="Microphone (max 10s)",
505
+ visible=False,
506
+ buttons=["download"],
507
+ elem_classes=["audio"],
508
+ )
509
+ file_audio = gr.Audio(
510
+ sources=["upload"],
511
+ type="filepath",
512
+ format="wav",
513
+ label="Upload file (WAV/MP3/FLAC etc., max 10s)",
514
+ visible=True,
515
+ buttons=["download"],
516
+ elem_classes=["audio"],
517
+ )
518
+
519
+ run_btn = gr.Button("Enhance", variant="primary", elem_id="run_btn")
520
+ status = gr.Markdown(elem_id="status_md")
521
+
522
+ gr.Markdown("## Results")
523
+
524
+ with gr.Row():
525
+ out_noisy = gr.Audio(label="Before (mono)", interactive=False, format="wav", buttons=["download"], elem_classes=["audio"])
526
+ out_enh = gr.Audio(label="After (enhanced)", interactive=False, format="wav", buttons=["download"], elem_classes=["audio"])
527
+
528
+ with gr.Row():
529
+ img_noisy = gr.Image(label="Noisy spectrogram", elem_classes=["spec"])
530
+ img_enh = gr.Image(label="Enhanced spectrogram", elem_classes=["spec"])
531
+
532
+ source.change(fn=set_source_visibility, inputs=source, outputs=[mic_audio, file_audio])
533
+ run_btn.click(
534
+ fn=run_enhancement,
535
+ inputs=[source, mic_audio, file_audio, model_key],
536
+ outputs=[out_noisy, out_enh, img_noisy, img_enh, status],
537
+ api_name="enhance",
538
+ )
539
+
540
+ if __name__ == "__main__":
541
+ demo.queue(max_size=32).launch()