ongudidan committed on
Commit
6024d1d
·
verified ·
1 Parent(s): faa83be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -36
app.py CHANGED
@@ -8,6 +8,10 @@ from typing import List, Optional, Tuple, Union
8
  import subprocess
9
  # import os
10
 
 
 
 
 
11
  import gradio as gr
12
  import matplotlib.pyplot as plt
13
  import numpy as np
@@ -112,71 +116,82 @@ def ensure_wav(filepath: str) -> str:
112
  return filepath
113
 
114
 
115
def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str] = None):
    """Mix optional noise into a speech clip, denoise it, and render the results.

    Args:
        speech_upl: Path of the uploaded speech file, or ``None`` to use the
            bundled sample clip.
        noise_type: Key into ``NOISES`` selecting the noise file (the mapped
            value may be ``None``, in which case no noise is mixed in).
        snr: Signal-to-noise ratio forwarded to ``mix_at_snr``; coerced to int.
        mic_input: Optional microphone recording path. When truthy it takes
            precedence over ``speech_upl``.

    Returns:
        ``(noisy_wav, noisy_im, enhanced_wav, enh_im)``: paths of the saved
        noisy and enhanced WAV files plus the values returned by ``spec_im``
        for each of them.
    """
    if mic_input:
        speech_upl = mic_input

    sr = config("sr", 48000, int, section="df")
    logger.info(f"Got parameters speech_upl: {speech_upl}, noise: {noise_type}, snr: {snr}")
    snr = int(snr)
    noise_fn = NOISES[noise_type]
    meta = AudioMetaData(-1, -1, -1, -1, "")

    max_s = 3600  # allow up to 1 hour (3600 seconds)
    max_len = max_s * sr

    if speech_upl is not None:
        # Ensure compatible WAV input
        speech_upl = ensure_wav(speech_upl)

        sample, meta = load_audio(speech_upl, sr)
        if sample.shape[-1] > max_len:
            # Keep a random window of at most max_len samples.
            start = torch.randint(0, sample.shape[-1] - max_len, ()).item()
            sample = sample[..., start : start + max_len]
    else:
        sample, meta = load_audio("samples/p232_013_clean.wav", sr)
        sample = sample[..., :max_len]

    if sample.dim() > 1 and sample.shape[0] > 1:
        assert sample.shape[1] > sample.shape[0], f"Expecting channels first, but got {sample.shape}"
        # Down-mix multi-channel audio to a single channel.
        sample = sample.mean(dim=0, keepdim=True)

    logger.info(f"Loaded sample with shape {sample.shape}")

    if noise_fn is not None:
        noise, _ = load_audio(noise_fn, sr)  # type: ignore
        logger.info(f"Loaded noise with shape {noise.shape}")
        _, _, sample = mix_at_snr(sample, noise, snr)

    logger.info("Start denoising audio")
    enhanced = enhance(model, df, sample)
    logger.info("Denoising finished")

    # Linear fade-in over the first 0.15 s. The ramp is clamped to the clip
    # length so clips shorter than 0.15 s no longer request a negative-sized
    # tensor from torch.ones().
    fade_len = min(int(sr * 0.15), enhanced.shape[1])
    lim = torch.linspace(0.0, 1.0, fade_len).unsqueeze(0)
    lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - fade_len)), dim=1)
    enhanced = enhanced * lim

    if meta.sample_rate != sr:
        # Return audio at the sample rate reported by load_audio.
        enhanced = resample(enhanced, sr, meta.sample_rate)
        sample = resample(sample, sr, meta.sample_rate)
        sr = meta.sample_rate

    # Reserve the output paths and close the handles right away; save_audio
    # reopens the files by name.
    with tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False) as tmp:
        noisy_wav = tmp.name
    save_audio(noisy_wav, sample, sr)
    with tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False) as tmp:
        enhanced_wav = tmp.name
    save_audio(enhanced_wav, enhanced, sr)

    logger.info(f"saved audios: {noisy_wav}, {enhanced_wav}")

    ax_noisy.clear()
    ax_enh.clear()
    noisy_im = spec_im(sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
    enh_im = spec_im(enhanced, sr=sr, figure=fig_enh, ax=ax_enh)

    # NOTE(review): presumably the paths cleanup_tmp must NOT delete — confirm
    # against its definition. Renamed from `filter` to stop shadowing the builtin.
    keep = [speech_upl, noisy_wav, enhanced_wav]
    if mic_input is not None and mic_input != "":
        keep.append(mic_input)
    cleanup_tmp(keep)

    return noisy_wav, noisy_im, enhanced_wav, enh_im
182
 
 
8
  import subprocess
9
  # import os
10
 
11
+ import asyncio
12
+ # from typing import Optional
13
+
14
+
15
  import gradio as gr
16
  import matplotlib.pyplot as plt
17
  import numpy as np
 
116
  return filepath
117
 
118
 
119
+
120
+
121
async def ensure_wav_async(filepath: Optional[str]) -> Optional[str]:
    """Return a WAV path for ``filepath``, converting MP3 input with FFmpeg.

    Paths ending in ``.mp3`` (case-insensitive) are converted by running
    ``ffmpeg -y -i <filepath> <wav_path>`` in a worker thread so the event
    loop is not blocked; ``wav_path`` is ``filepath`` with its extension
    replaced by ``.wav``. Any other value is returned unchanged.

    Args:
        filepath: Path of the input audio file. ``None`` or ``""`` (nothing
            uploaded) is passed through untouched.

    Returns:
        The converted ``.wav`` path for MP3 input, otherwise ``filepath``.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    # Guard against "no file" before calling .lower(), which would otherwise
    # raise AttributeError on None.
    if not filepath or not filepath.lower().endswith(".mp3"):
        return filepath

    wav_path = filepath.rsplit(".", 1)[0] + ".wav"
    # Argument list (no shell); runs off the event loop thread.
    await asyncio.to_thread(
        subprocess.run,
        ["ffmpeg", "-y", "-i", filepath, wav_path],
        check=True,
    )
    return wav_path
130
+
131
+
132
async def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str] = None, progress=gr.Progress()):
    """Asynchronously mix optional noise into a speech clip and denoise it.

    Blocking steps (conversion, loading, mixing, enhancement, resampling,
    saving, spectrogram rendering) are awaited in worker threads, and
    ``progress`` is updated between stages.

    Args:
        speech_upl: Path of the uploaded speech file, or ``None`` to use the
            bundled sample clip.
        noise_type: Key into ``NOISES`` selecting the noise file (the mapped
            value may be ``None``, in which case no noise is mixed in).
        snr: Signal-to-noise ratio forwarded to ``mix_at_snr``; coerced to int.
        mic_input: Optional microphone recording path. When truthy it takes
            precedence over ``speech_upl``.
        progress: Gradio progress tracker; called with a completion fraction
            in ``[0, 1]`` and a stage description.

    Returns:
        ``(noisy_wav, noisy_im, enhanced_wav, enh_im)``: paths of the saved
        noisy and enhanced WAV files plus the values returned by ``spec_im``
        for each of them.
    """
    if mic_input:
        speech_upl = mic_input

    sr = config("sr", 48000, int, section="df")
    snr = int(snr)
    noise_fn = NOISES[noise_type]
    meta = AudioMetaData(-1, -1, -1, -1, "")

    max_s = 3600  # 1 hour
    max_len = max_s * sr

    # Stage 1: Upload / Convert
    # gr.Progress takes a fraction between 0 and 1, not a percentage.
    progress(0.0, desc="Converting audio...")
    if speech_upl is not None:
        speech_upl = await ensure_wav_async(speech_upl)

        # Stage 2: Load audio
        progress(0.1, desc="Loading audio...")
        sample, meta = await asyncio.to_thread(load_audio, speech_upl, sr)
        if sample.shape[-1] > max_len:
            # Keep a random window of at most max_len samples.
            start = torch.randint(0, sample.shape[-1] - max_len, ()).item()
            sample = sample[..., start : start + max_len]
    else:
        # Nothing uploaded or recorded: fall back to the bundled sample clip
        # instead of failing on a None path.
        progress(0.1, desc="Loading audio...")
        sample, meta = await asyncio.to_thread(load_audio, "samples/p232_013_clean.wav", sr)
        sample = sample[..., :max_len]

    if sample.dim() > 1 and sample.shape[0] > 1:
        # Down-mix multi-channel audio (assumes channels-first — TODO confirm).
        sample = sample.mean(dim=0, keepdim=True)

    # Stage 3: Mix noise if applicable
    progress(0.3, desc="Mixing noise...")
    if noise_fn is not None:
        noise, _ = await asyncio.to_thread(load_audio, noise_fn, sr)
        _, _, sample = await asyncio.to_thread(mix_at_snr, sample, noise, snr)

    # Stage 4: Denoising
    progress(0.6, desc="Denoising...")
    enhanced = await asyncio.to_thread(enhance, model, df, sample)

    # Linear fade-in over the first 0.15 s. The ramp is clamped to the clip
    # length so clips shorter than 0.15 s no longer request a negative-sized
    # tensor from torch.ones().
    fade_len = min(int(sr * 0.15), enhanced.shape[1])
    lim = torch.linspace(0.0, 1.0, fade_len).unsqueeze(0)
    lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - fade_len)), dim=1)
    enhanced = enhanced * lim

    if meta.sample_rate != sr:
        # Return audio at the sample rate reported by load_audio.
        enhanced = await asyncio.to_thread(resample, enhanced, sr, meta.sample_rate)
        sample = await asyncio.to_thread(resample, sample, sr, meta.sample_rate)
        sr = meta.sample_rate

    # Stage 5: Save outputs
    progress(0.9, desc="Saving files...")
    # Reserve the output paths and close the handles right away; save_audio
    # reopens the files by name.
    with tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False) as tmp:
        noisy_wav = tmp.name
    with tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False) as tmp:
        enhanced_wav = tmp.name

    await asyncio.to_thread(save_audio, noisy_wav, sample, sr)
    await asyncio.to_thread(save_audio, enhanced_wav, enhanced, sr)

    # Reset the shared axes before drawing so earlier requests do not bleed
    # into this one.
    # NOTE(review): the figures/axes are module-level and rendered from worker
    # threads; concurrent requests may interfere — confirm queueing settings.
    ax_noisy.clear()
    ax_enh.clear()
    noisy_im = await asyncio.to_thread(spec_im, sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
    enh_im = await asyncio.to_thread(spec_im, enhanced, sr=sr, figure=fig_enh, ax=ax_enh)

    # Cleanup temp files
    # NOTE(review): presumably the paths cleanup_tmp must NOT delete — confirm
    # against its definition. The original recording path is listed as well,
    # since speech_upl may now point at the converted WAV.
    keep = [speech_upl, noisy_wav, enhanced_wav]
    if mic_input is not None and mic_input != "":
        keep.append(mic_input)
    cleanup_tmp(keep)

    progress(1.0, desc="Done!")

    return noisy_wav, noisy_im, enhanced_wav, enh_im
197