ongudidan committed on
Commit
a030856
·
verified ·
1 Parent(s): 6024d1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -53
app.py CHANGED
@@ -8,9 +8,11 @@ from typing import List, Optional, Tuple, Union
8
  import subprocess
9
  # import os
10
 
11
- import asyncio
 
 
12
  # from typing import Optional
13
-
14
 
15
  import gradio as gr
16
  import matplotlib.pyplot as plt
@@ -107,30 +109,17 @@ def load_audio_gradio(
107
  return audio, meta
108
 
109
 
110
- def ensure_wav(filepath: str) -> str:
111
- """Convert MP3 (or other formats) to WAV using ffmpeg if needed."""
112
- if filepath.lower().endswith(".mp3"):
113
- wav_path = filepath.rsplit(".", 1)[0] + ".wav"
114
- subprocess.run(["ffmpeg", "-y", "-i", filepath, wav_path], check=True)
115
- return wav_path
116
- return filepath
117
-
118
-
119
 
120
 
121
- async def ensure_wav_async(filepath: str) -> str:
122
- """Async wrapper for FFmpeg conversion."""
123
- if filepath.lower().endswith(".mp3"):
124
- wav_path = filepath.rsplit(".", 1)[0] + ".wav"
125
- # Run ffmpeg in a thread to avoid blocking
126
- loop = asyncio.get_running_loop()
127
- await loop.run_in_executor(None, lambda: subprocess.run(["ffmpeg", "-y", "-i", filepath, wav_path], check=True))
128
- return wav_path
129
- return filepath
130
 
 
 
 
 
 
 
131
 
132
- async def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str] = None, progress=gr.Progress()):
133
-
134
  if mic_input:
135
  speech_upl = mic_input
136
 
@@ -140,60 +129,70 @@ async def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optiona
140
  meta = AudioMetaData(-1, -1, -1, -1, "")
141
 
142
  max_s = 3600 # 1 hour
 
 
143
 
144
- # Stage 1: Upload / Convert
145
- progress(0, desc="Converting audio...")
146
- speech_upl = await ensure_wav_async(speech_upl)
147
-
148
- # Stage 2: Load audio
149
- progress(10, desc="Loading audio...")
150
- sample, meta = await asyncio.to_thread(load_audio, speech_upl, sr)
151
 
152
- max_len = max_s * sr
153
- if sample.shape[-1] > max_len:
154
- start = torch.randint(0, sample.shape[-1] - max_len, ()).item()
155
- sample = sample[..., start : start + max_len]
156
 
 
157
  if sample.dim() > 1 and sample.shape[0] > 1:
158
  sample = sample.mean(dim=0, keepdim=True)
159
 
160
- # Stage 3: Mix noise if applicable
161
- progress(30, desc="Mixing noise...")
162
  if noise_fn is not None:
163
- noise, _ = await asyncio.to_thread(load_audio, noise_fn, sr)
164
- _, _, sample = await asyncio.to_thread(mix_at_snr, sample, noise, snr)
 
 
 
165
 
166
- # Stage 4: Denoising
167
- progress(60, desc="Denoising...")
168
- enhanced = await asyncio.to_thread(enhance, model, df, sample)
 
 
169
 
 
 
 
 
 
 
 
170
  lim = torch.linspace(0.0, 1.0, int(sr * 0.15)).unsqueeze(0)
171
  lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1)
172
  enhanced = enhanced * lim
173
 
 
174
  if meta.sample_rate != sr:
175
- enhanced = await asyncio.to_thread(resample, enhanced, sr, meta.sample_rate)
176
- sample = await asyncio.to_thread(resample, sample, sr, meta.sample_rate)
177
  sr = meta.sample_rate
178
 
179
- # Stage 5: Save outputs
180
- progress(90, desc="Saving files...")
181
  noisy_wav = tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False).name
182
  enhanced_wav = tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False).name
 
 
183
 
184
- await asyncio.to_thread(save_audio, noisy_wav, sample, sr)
185
- await asyncio.to_thread(save_audio, enhanced_wav, enhanced, sr)
 
 
 
186
 
187
- progress(100, desc="Done!")
 
188
 
189
- # Optional: generate spectrograms (can also be offloaded to thread)
190
- noisy_im = await asyncio.to_thread(spec_im, sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
191
- enh_im = await asyncio.to_thread(spec_im, enhanced, sr=sr, figure=fig_enh, ax=ax_enh)
192
 
193
- # Cleanup temp files
194
- cleanup_tmp([speech_upl, noisy_wav, enhanced_wav])
195
 
196
- return noisy_wav, noisy_im, enhanced_wav, enh_im
197
 
198
  def specshow(
199
  spec,
 
8
  import subprocess
9
  # import os
10
 
11
+ # import torch
12
+ # import numpy as np
13
+ # import tempfile
14
  # from typing import Optional
15
+ # import gradio as gr
16
 
17
  import gradio as gr
18
  import matplotlib.pyplot as plt
 
109
  return audio, meta
110
 
111
 
 
 
 
 
 
 
 
 
 
112
 
113
 
 
 
 
 
 
 
 
 
 
114
 
115
+ def chunk_audio(sample: torch.Tensor, chunk_size: int):
116
+ """Yield chunks of audio of size `chunk_size`."""
117
+ total_len = sample.shape[-1]
118
+ for start in range(0, total_len, chunk_size):
119
+ end = min(start + chunk_size, total_len)
120
+ yield sample[..., start:end], start, total_len
121
 
122
+ def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str] = None, progress=gr.Progress()):
 
123
  if mic_input:
124
  speech_upl = mic_input
125
 
 
129
  meta = AudioMetaData(-1, -1, -1, -1, "")
130
 
131
  max_s = 3600 # 1 hour
132
+ chunk_s = 10 # process in 10-second chunks
133
+ chunk_len = chunk_s * sr
134
 
135
+ # Load audio
136
+ speech_upl = ensure_wav(speech_upl)
137
+ sample, meta = load_audio(speech_upl, sr)
 
 
 
 
138
 
139
+ # Limit to max_s
140
+ if sample.shape[-1] > max_s * sr:
141
+ start_idx = torch.randint(0, sample.shape[-1] - max_s*sr, ()).item()
142
+ sample = sample[..., start_idx:start_idx + max_s*sr]
143
 
144
+ # Convert to mono if needed
145
  if sample.dim() > 1 and sample.shape[0] > 1:
146
  sample = sample.mean(dim=0, keepdim=True)
147
 
148
+ # Mix noise if applicable
 
149
  if noise_fn is not None:
150
+ noise, _ = load_audio(noise_fn, sr)
151
+ _, _, sample = mix_at_snr(sample, noise, snr)
152
+
153
+ # Prepare output tensor
154
+ enhanced_chunks = []
155
 
156
+ # Process audio in chunks
157
+ for i, (chunk, start, total_len) in enumerate(chunk_audio(sample, chunk_len)):
158
+ # Denoise the chunk
159
+ enhanced_chunk = enhance(model, df, chunk)
160
+ enhanced_chunks.append(enhanced_chunk)
161
 
162
+ # Update progress
163
+ progress((start + chunk.shape[-1]) / total_len * 100, desc="Denoising audio...")
164
+
165
+ # Concatenate all chunks
166
+ enhanced = torch.cat(enhanced_chunks, dim=-1)
167
+
168
+ # Optional: apply fade or limiter
169
  lim = torch.linspace(0.0, 1.0, int(sr * 0.15)).unsqueeze(0)
170
  lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1)
171
  enhanced = enhanced * lim
172
 
173
+ # Resample if needed
174
  if meta.sample_rate != sr:
175
+ enhanced = resample(enhanced, sr, meta.sample_rate)
176
+ sample = resample(sample, sr, meta.sample_rate)
177
  sr = meta.sample_rate
178
 
179
+ # Save outputs
 
180
  noisy_wav = tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False).name
181
  enhanced_wav = tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False).name
182
+ save_audio(noisy_wav, sample, sr)
183
+ save_audio(enhanced_wav, enhanced, sr)
184
 
185
+ # Spectrograms
186
+ ax_noisy.clear()
187
+ ax_enh.clear()
188
+ noisy_im = spec_im(sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
189
+ enh_im = spec_im(enhanced, sr=sr, figure=fig_enh, ax=ax_enh)
190
 
191
+ cleanup_tmp([speech_upl, noisy_wav, enhanced_wav])
192
+ return noisy_wav, noisy_im, enhanced_wav, enh_im
193
 
 
 
 
194
 
 
 
195
 
 
196
 
197
  def specshow(
198
  spec,