ongudidan committed on
Commit
a2adea6
·
verified ·
1 Parent(s): d67897c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -34
app.py CHANGED
@@ -5,6 +5,9 @@ import tempfile
5
  import time
6
  from typing import List, Optional, Tuple, Union
7
 
 
 
 
8
  import gradio as gr
9
  import matplotlib.pyplot as plt
10
  import numpy as np
@@ -100,81 +103,76 @@ def load_audio_gradio(
100
  return audio, meta
101
 
102
 
 
 
 
 
 
 
 
 
 
103
  def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str] = None):
104
  if mic_input:
105
  speech_upl = mic_input
 
106
  sr = config("sr", 48000, int, section="df")
107
  logger.info(f"Got parameters speech_upl: {speech_upl}, noise: {noise_type}, snr: {snr}")
108
  snr = int(snr)
109
  noise_fn = NOISES[noise_type]
110
  meta = AudioMetaData(-1, -1, -1, -1, "")
111
 
112
- # --- Load audio (full file, no hard 10s crop) ---
 
113
  if speech_upl is not None:
 
 
 
114
  sample, meta = load_audio(speech_upl, sr)
 
 
 
 
115
  else:
116
  sample, meta = load_audio("samples/p232_013_clean.wav", sr)
 
117
 
118
- # Mix to mono if multi-channel
119
  if sample.dim() > 1 and sample.shape[0] > 1:
120
- assert (
121
- sample.shape[1] > sample.shape[0]
122
- ), f"Expecting channels first, but got {sample.shape}"
123
  sample = sample.mean(dim=0, keepdim=True)
124
 
125
  logger.info(f"Loaded sample with shape {sample.shape}")
126
 
127
- # Add noise if requested
128
  if noise_fn is not None:
129
  noise, _ = load_audio(noise_fn, sr) # type: ignore
130
  logger.info(f"Loaded noise with shape {noise.shape}")
131
  _, _, sample = mix_at_snr(sample, noise, snr)
132
 
133
  logger.info("Start denoising audio")
134
-
135
- # --- Process in chunks instead of single forward pass ---
136
- chunk_size = sr * 10 # 10 seconds per chunk
137
- enhanced_chunks = []
138
-
139
- for i in range(0, sample.shape[-1], chunk_size):
140
- chunk = sample[..., i:i + chunk_size]
141
- if chunk.shape[-1] == 0:
142
- continue
143
- logger.info(f"Enhancing chunk {i//chunk_size + 1}")
144
- enhanced_chunk = enhance(model, df, chunk)
145
-
146
- # Apply short fade-in to smooth first chunk
147
- if i == 0:
148
- lim = torch.linspace(0.0, 1.0, int(sr * 0.15)).unsqueeze(0)
149
- lim = torch.cat((lim, torch.ones(1, enhanced_chunk.shape[1] - lim.shape[1])), dim=1)
150
- enhanced_chunk = enhanced_chunk * lim
151
-
152
- enhanced_chunks.append(enhanced_chunk)
153
-
154
- # Concatenate all enhanced chunks into one
155
- enhanced = torch.cat(enhanced_chunks, dim=-1)
156
  logger.info("Denoising finished")
157
 
158
- # Resample back if needed
 
 
 
159
  if meta.sample_rate != sr:
160
  enhanced = resample(enhanced, sr, meta.sample_rate)
161
  sample = resample(sample, sr, meta.sample_rate)
162
  sr = meta.sample_rate
163
 
164
- # Save noisy & enhanced wavs
165
  noisy_wav = tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False).name
166
  save_audio(noisy_wav, sample, sr)
167
  enhanced_wav = tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False).name
168
  save_audio(enhanced_wav, enhanced, sr)
169
- logger.info(f"Saved audios: {noisy_wav}, {enhanced_wav}")
170
 
171
- # Plot spectrograms
 
172
  ax_noisy.clear()
173
  ax_enh.clear()
174
  noisy_im = spec_im(sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
175
  enh_im = spec_im(enhanced, sr=sr, figure=fig_enh, ax=ax_enh)
176
 
177
- # Cleanup temp files
178
  filter = [speech_upl, noisy_wav, enhanced_wav]
179
  if mic_input is not None and mic_input != "":
180
  filter.append(mic_input)
@@ -182,7 +180,6 @@ def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str]
182
 
183
  return noisy_wav, noisy_im, enhanced_wav, enh_im
184
 
185
-
186
  def specshow(
187
  spec,
188
  ax=None,
 
5
  import time
6
  from typing import List, Optional, Tuple, Union
7
 
8
+ import subprocess
9
+ # import os
10
+
11
  import gradio as gr
12
  import matplotlib.pyplot as plt
13
  import numpy as np
 
103
  return audio, meta
104
 
105
 
106
def ensure_wav(filepath: str) -> str:
    """Return a WAV path for *filepath*, converting with ffmpeg when needed.

    The original code only converted ``.mp3`` even though the docstring
    promised "MP3 (or other formats)"; this version transcodes any
    non-WAV input. The converted audio is written to a fresh temporary
    ``.wav`` file so we never clobber a pre-existing ``<stem>.wav`` next
    to the upload and so conversion works even if the upload directory
    is read-only.

    Args:
        filepath: Path to the uploaded audio file.

    Returns:
        ``filepath`` unchanged when it already ends in ``.wav``
        (case-insensitive), otherwise the path of the converted
        temporary WAV file.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits non-zero
            (``check=True``) — fail loudly rather than feed a broken
            file to the model downstream.
    """
    # Already a WAV (any case): nothing to do.
    if filepath.lower().endswith(".wav"):
        return filepath
    # Transcode any other container/codec (mp3, flac, ogg, ...) to WAV.
    wav_path = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name
    subprocess.run(["ffmpeg", "-y", "-i", filepath, wav_path], check=True)
    return wav_path
113
+
114
+
115
  def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str] = None):
116
  if mic_input:
117
  speech_upl = mic_input
118
+
119
  sr = config("sr", 48000, int, section="df")
120
  logger.info(f"Got parameters speech_upl: {speech_upl}, noise: {noise_type}, snr: {snr}")
121
  snr = int(snr)
122
  noise_fn = NOISES[noise_type]
123
  meta = AudioMetaData(-1, -1, -1, -1, "")
124
 
125
+ max_s = 3600 # allow up to 1 hour (3600 seconds)
126
+
127
  if speech_upl is not None:
128
+ # ✅ Ensure compatible WAV input
129
+ speech_upl = ensure_wav(speech_upl)
130
+
131
  sample, meta = load_audio(speech_upl, sr)
132
+ max_len = max_s * sr
133
+ if sample.shape[-1] > max_len:
134
+ start = torch.randint(0, sample.shape[-1] - max_len, ()).item()
135
+ sample = sample[..., start : start + max_len]
136
  else:
137
  sample, meta = load_audio("samples/p232_013_clean.wav", sr)
138
+ sample = sample[..., : max_s * sr]
139
 
 
140
  if sample.dim() > 1 and sample.shape[0] > 1:
141
+ assert sample.shape[1] > sample.shape[0], f"Expecting channels first, but got {sample.shape}"
 
 
142
  sample = sample.mean(dim=0, keepdim=True)
143
 
144
  logger.info(f"Loaded sample with shape {sample.shape}")
145
 
 
146
  if noise_fn is not None:
147
  noise, _ = load_audio(noise_fn, sr) # type: ignore
148
  logger.info(f"Loaded noise with shape {noise.shape}")
149
  _, _, sample = mix_at_snr(sample, noise, snr)
150
 
151
  logger.info("Start denoising audio")
152
+ enhanced = enhance(model, df, sample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  logger.info("Denoising finished")
154
 
155
+ lim = torch.linspace(0.0, 1.0, int(sr * 0.15)).unsqueeze(0)
156
+ lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1)
157
+ enhanced = enhanced * lim
158
+
159
  if meta.sample_rate != sr:
160
  enhanced = resample(enhanced, sr, meta.sample_rate)
161
  sample = resample(sample, sr, meta.sample_rate)
162
  sr = meta.sample_rate
163
 
 
164
  noisy_wav = tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False).name
165
  save_audio(noisy_wav, sample, sr)
166
  enhanced_wav = tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False).name
167
  save_audio(enhanced_wav, enhanced, sr)
 
168
 
169
+ logger.info(f"saved audios: {noisy_wav}, {enhanced_wav}")
170
+
171
  ax_noisy.clear()
172
  ax_enh.clear()
173
  noisy_im = spec_im(sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
174
  enh_im = spec_im(enhanced, sr=sr, figure=fig_enh, ax=ax_enh)
175
 
 
176
  filter = [speech_upl, noisy_wav, enhanced_wav]
177
  if mic_input is not None and mic_input != "":
178
  filter.append(mic_input)
 
180
 
181
  return noisy_wav, noisy_im, enhanced_wav, enh_im
182
 
 
183
  def specshow(
184
  spec,
185
  ax=None,