ongudidan commited on
Commit
d67897c
·
verified ·
1 Parent(s): ea29c1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -11
app.py CHANGED
@@ -108,49 +108,78 @@ def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str]
108
  snr = int(snr)
109
  noise_fn = NOISES[noise_type]
110
  meta = AudioMetaData(-1, -1, -1, -1, "")
111
- max_s = 10 # limit to 10 seconds
 
112
  if speech_upl is not None:
113
  sample, meta = load_audio(speech_upl, sr)
114
- max_len = max_s * sr
115
- if sample.shape[-1] > max_len:
116
- start = torch.randint(0, sample.shape[-1] - max_len, ()).item()
117
- sample = sample[..., start : start + max_len]
118
  else:
119
  sample, meta = load_audio("samples/p232_013_clean.wav", sr)
120
- sample = sample[..., : max_s * sr]
 
121
  if sample.dim() > 1 and sample.shape[0] > 1:
122
  assert (
123
  sample.shape[1] > sample.shape[0]
124
  ), f"Expecting channels first, but got {sample.shape}"
125
  sample = sample.mean(dim=0, keepdim=True)
 
126
  logger.info(f"Loaded sample with shape {sample.shape}")
 
 
127
  if noise_fn is not None:
128
  noise, _ = load_audio(noise_fn, sr) # type: ignore
129
  logger.info(f"Loaded noise with shape {noise.shape}")
130
  _, _, sample = mix_at_snr(sample, noise, snr)
 
131
  logger.info("Start denoising audio")
132
- enhanced = enhance(model, df, sample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  logger.info("Denoising finished")
134
- lim = torch.linspace(0.0, 1.0, int(sr * 0.15)).unsqueeze(0)
135
- lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1)
136
- enhanced = enhanced * lim
137
  if meta.sample_rate != sr:
138
  enhanced = resample(enhanced, sr, meta.sample_rate)
139
  sample = resample(sample, sr, meta.sample_rate)
140
  sr = meta.sample_rate
 
 
141
  noisy_wav = tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False).name
142
  save_audio(noisy_wav, sample, sr)
143
  enhanced_wav = tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False).name
144
  save_audio(enhanced_wav, enhanced, sr)
145
- logger.info(f"saved audios: {noisy_wav}, {enhanced_wav}")
 
 
146
  ax_noisy.clear()
147
  ax_enh.clear()
148
  noisy_im = spec_im(sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
149
  enh_im = spec_im(enhanced, sr=sr, figure=fig_enh, ax=ax_enh)
 
 
150
  filter = [speech_upl, noisy_wav, enhanced_wav]
151
  if mic_input is not None and mic_input != "":
152
  filter.append(mic_input)
153
  cleanup_tmp(filter)
 
154
  return noisy_wav, noisy_im, enhanced_wav, enh_im
155
 
156
 
 
108
  snr = int(snr)
109
  noise_fn = NOISES[noise_type]
110
  meta = AudioMetaData(-1, -1, -1, -1, "")
111
+
112
+ # --- Load audio (full file, no hard 10s crop) ---
113
  if speech_upl is not None:
114
  sample, meta = load_audio(speech_upl, sr)
 
 
 
 
115
  else:
116
  sample, meta = load_audio("samples/p232_013_clean.wav", sr)
117
+
118
+ # Mix to mono if multi-channel
119
  if sample.dim() > 1 and sample.shape[0] > 1:
120
  assert (
121
  sample.shape[1] > sample.shape[0]
122
  ), f"Expecting channels first, but got {sample.shape}"
123
  sample = sample.mean(dim=0, keepdim=True)
124
+
125
  logger.info(f"Loaded sample with shape {sample.shape}")
126
+
127
+ # Add noise if requested
128
  if noise_fn is not None:
129
  noise, _ = load_audio(noise_fn, sr) # type: ignore
130
  logger.info(f"Loaded noise with shape {noise.shape}")
131
  _, _, sample = mix_at_snr(sample, noise, snr)
132
+
133
  logger.info("Start denoising audio")
134
+
135
+ # --- Process in chunks instead of single forward pass ---
136
+ chunk_size = sr * 10 # 10 seconds per chunk
137
+ enhanced_chunks = []
138
+
139
+ for i in range(0, sample.shape[-1], chunk_size):
140
+ chunk = sample[..., i:i + chunk_size]
141
+ if chunk.shape[-1] == 0:
142
+ continue
143
+ logger.info(f"Enhancing chunk {i//chunk_size + 1}")
144
+ enhanced_chunk = enhance(model, df, chunk)
145
+
146
+ # Apply short fade-in to smooth first chunk
147
+ if i == 0:
148
+ lim = torch.linspace(0.0, 1.0, int(sr * 0.15)).unsqueeze(0)
149
+ lim = torch.cat((lim, torch.ones(1, enhanced_chunk.shape[1] - lim.shape[1])), dim=1)
150
+ enhanced_chunk = enhanced_chunk * lim
151
+
152
+ enhanced_chunks.append(enhanced_chunk)
153
+
154
+ # Concatenate all enhanced chunks into one
155
+ enhanced = torch.cat(enhanced_chunks, dim=-1)
156
  logger.info("Denoising finished")
157
+
158
+ # Resample back if needed
 
159
  if meta.sample_rate != sr:
160
  enhanced = resample(enhanced, sr, meta.sample_rate)
161
  sample = resample(sample, sr, meta.sample_rate)
162
  sr = meta.sample_rate
163
+
164
+ # Save noisy & enhanced wavs
165
  noisy_wav = tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False).name
166
  save_audio(noisy_wav, sample, sr)
167
  enhanced_wav = tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False).name
168
  save_audio(enhanced_wav, enhanced, sr)
169
+ logger.info(f"Saved audios: {noisy_wav}, {enhanced_wav}")
170
+
171
+ # Plot spectrograms
172
  ax_noisy.clear()
173
  ax_enh.clear()
174
  noisy_im = spec_im(sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
175
  enh_im = spec_im(enhanced, sr=sr, figure=fig_enh, ax=ax_enh)
176
+
177
+ # Cleanup temp files
178
  filter = [speech_upl, noisy_wav, enhanced_wav]
179
  if mic_input is not None and mic_input != "":
180
  filter.append(mic_input)
181
  cleanup_tmp(filter)
182
+
183
  return noisy_wav, noisy_im, enhanced_wav, enh_im
184
 
185