ongudidan committed on
Commit
cb1e8bf
·
verified ·
1 Parent(s): a030856

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -53
app.py CHANGED
@@ -8,12 +8,6 @@ from typing import List, Optional, Tuple, Union
8
  import subprocess
9
  # import os
10
 
11
- # import torch
12
- # import numpy as np
13
- # import tempfile
14
- # from typing import Optional
15
- # import gradio as gr
16
-
17
  import gradio as gr
18
  import matplotlib.pyplot as plt
19
  import numpy as np
@@ -109,90 +103,82 @@ def load_audio_gradio(
109
  return audio, meta
110
 
111
 
 
 
 
 
 
 
 
112
 
113
 
114
-
115
- def chunk_audio(sample: torch.Tensor, chunk_size: int):
116
- """Yield chunks of audio of size `chunk_size`."""
117
- total_len = sample.shape[-1]
118
- for start in range(0, total_len, chunk_size):
119
- end = min(start + chunk_size, total_len)
120
- yield sample[..., start:end], start, total_len
121
-
122
- def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str] = None, progress=gr.Progress()):
123
  if mic_input:
124
  speech_upl = mic_input
125
 
126
  sr = config("sr", 48000, int, section="df")
 
127
  snr = int(snr)
128
  noise_fn = NOISES[noise_type]
129
  meta = AudioMetaData(-1, -1, -1, -1, "")
130
 
131
- max_s = 3600 # 1 hour
132
- chunk_s = 10 # process in 10-second chunks
133
- chunk_len = chunk_s * sr
134
 
135
- # Load audio
136
- speech_upl = ensure_wav(speech_upl)
137
- sample, meta = load_audio(speech_upl, sr)
138
 
139
- # Limit to max_s
140
- if sample.shape[-1] > max_s * sr:
141
- start_idx = torch.randint(0, sample.shape[-1] - max_s*sr, ()).item()
142
- sample = sample[..., start_idx:start_idx + max_s*sr]
 
 
 
 
143
 
144
- # Convert to mono if needed
145
  if sample.dim() > 1 and sample.shape[0] > 1:
 
146
  sample = sample.mean(dim=0, keepdim=True)
147
 
148
- # Mix noise if applicable
 
149
  if noise_fn is not None:
150
- noise, _ = load_audio(noise_fn, sr)
 
151
  _, _, sample = mix_at_snr(sample, noise, snr)
152
 
153
- # Prepare output tensor
154
- enhanced_chunks = []
155
-
156
- # Process audio in chunks
157
- for i, (chunk, start, total_len) in enumerate(chunk_audio(sample, chunk_len)):
158
- # Denoise the chunk
159
- enhanced_chunk = enhance(model, df, chunk)
160
- enhanced_chunks.append(enhanced_chunk)
161
-
162
- # Update progress
163
- progress((start + chunk.shape[-1]) / total_len * 100, desc="Denoising audio...")
164
-
165
- # Concatenate all chunks
166
- enhanced = torch.cat(enhanced_chunks, dim=-1)
167
 
168
- # Optional: apply fade or limiter
169
  lim = torch.linspace(0.0, 1.0, int(sr * 0.15)).unsqueeze(0)
170
  lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1)
171
  enhanced = enhanced * lim
172
 
173
- # Resample if needed
174
  if meta.sample_rate != sr:
175
  enhanced = resample(enhanced, sr, meta.sample_rate)
176
  sample = resample(sample, sr, meta.sample_rate)
177
  sr = meta.sample_rate
178
 
179
- # Save outputs
180
  noisy_wav = tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False).name
181
- enhanced_wav = tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False).name
182
  save_audio(noisy_wav, sample, sr)
 
183
  save_audio(enhanced_wav, enhanced, sr)
184
 
185
- # Spectrograms
 
186
  ax_noisy.clear()
187
  ax_enh.clear()
188
  noisy_im = spec_im(sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
189
  enh_im = spec_im(enhanced, sr=sr, figure=fig_enh, ax=ax_enh)
190
 
191
- cleanup_tmp([speech_upl, noisy_wav, enhanced_wav])
192
- return noisy_wav, noisy_im, enhanced_wav, enh_im
193
-
194
-
195
 
 
196
 
197
  def specshow(
198
  spec,
@@ -373,7 +359,7 @@ with gr.Blocks() as demo:
373
 
374
  cleanup_tmp()
375
  # demo.launch(enable_queue=True)
376
- demo.launch()
377
- # demo.queue().launch()
378
 
379
 
 
8
  import subprocess
9
  # import os
10
 
 
 
 
 
 
 
11
  import gradio as gr
12
  import matplotlib.pyplot as plt
13
  import numpy as np
 
103
  return audio, meta
104
 
105
 
106
+ def ensure_wav(filepath: str) -> str:
107
+ """Convert MP3 (or other formats) to WAV using ffmpeg if needed."""
108
+ if filepath.lower().endswith(".mp3"):
109
+ wav_path = filepath.rsplit(".", 1)[0] + ".wav"
110
+ subprocess.run(["ffmpeg", "-y", "-i", filepath, wav_path], check=True)
111
+ return wav_path
112
+ return filepath
113
 
114
 
115
+ def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str] = None):
 
 
 
 
 
 
 
 
116
  if mic_input:
117
  speech_upl = mic_input
118
 
119
  sr = config("sr", 48000, int, section="df")
120
+ logger.info(f"Got parameters speech_upl: {speech_upl}, noise: {noise_type}, snr: {snr}")
121
  snr = int(snr)
122
  noise_fn = NOISES[noise_type]
123
  meta = AudioMetaData(-1, -1, -1, -1, "")
124
 
125
+ max_s = 3600 # allow up to 1 hour (3600 seconds)
 
 
126
 
127
+ if speech_upl is not None:
128
+ # Ensure compatible WAV input
129
+ speech_upl = ensure_wav(speech_upl)
130
 
131
+ sample, meta = load_audio(speech_upl, sr)
132
+ max_len = max_s * sr
133
+ if sample.shape[-1] > max_len:
134
+ start = torch.randint(0, sample.shape[-1] - max_len, ()).item()
135
+ sample = sample[..., start : start + max_len]
136
+ else:
137
+ sample, meta = load_audio("samples/p232_013_clean.wav", sr)
138
+ sample = sample[..., : max_s * sr]
139
 
 
140
  if sample.dim() > 1 and sample.shape[0] > 1:
141
+ assert sample.shape[1] > sample.shape[0], f"Expecting channels first, but got {sample.shape}"
142
  sample = sample.mean(dim=0, keepdim=True)
143
 
144
+ logger.info(f"Loaded sample with shape {sample.shape}")
145
+
146
  if noise_fn is not None:
147
+ noise, _ = load_audio(noise_fn, sr) # type: ignore
148
+ logger.info(f"Loaded noise with shape {noise.shape}")
149
  _, _, sample = mix_at_snr(sample, noise, snr)
150
 
151
+ logger.info("Start denoising audio")
152
+ enhanced = enhance(model, df, sample)
153
+ logger.info("Denoising finished")
 
 
 
 
 
 
 
 
 
 
 
154
 
 
155
  lim = torch.linspace(0.0, 1.0, int(sr * 0.15)).unsqueeze(0)
156
  lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1)
157
  enhanced = enhanced * lim
158
 
 
159
  if meta.sample_rate != sr:
160
  enhanced = resample(enhanced, sr, meta.sample_rate)
161
  sample = resample(sample, sr, meta.sample_rate)
162
  sr = meta.sample_rate
163
 
 
164
  noisy_wav = tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False).name
 
165
  save_audio(noisy_wav, sample, sr)
166
+ enhanced_wav = tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False).name
167
  save_audio(enhanced_wav, enhanced, sr)
168
 
169
+ logger.info(f"saved audios: {noisy_wav}, {enhanced_wav}")
170
+
171
  ax_noisy.clear()
172
  ax_enh.clear()
173
  noisy_im = spec_im(sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
174
  enh_im = spec_im(enhanced, sr=sr, figure=fig_enh, ax=ax_enh)
175
 
176
+ filter = [speech_upl, noisy_wav, enhanced_wav]
177
+ if mic_input is not None and mic_input != "":
178
+ filter.append(mic_input)
179
+ cleanup_tmp(filter)
180
 
181
+ return noisy_wav, noisy_im, enhanced_wav, enh_im
182
 
183
  def specshow(
184
  spec,
 
359
 
360
  cleanup_tmp()
361
  # demo.launch(enable_queue=True)
362
+ # demo.launch()
363
+ demo.queue().launch()
364
 
365