vera6 committed on
Commit
e343fe3
·
verified ·
1 Parent(s): 01e806a

Update app/app.py

Browse files
Files changed (1) hide show
  1. app/app.py +148 -145
app/app.py CHANGED
@@ -1,12 +1,12 @@
1
- import fastapi
2
- import shutil
3
- import os
4
- import zipfile
5
- import io
6
- import uvicorn
7
  import threading
8
  import glob
9
- from typing import List
10
  import torch
11
  import gdown
12
  from soundfile import write
@@ -14,97 +14,25 @@ from torchaudio import load
14
  from librosa import resample
15
  import logging
16
 
17
- import librosa
18
- import numpy as np
19
- from scipy.signal import butter, filtfilt
20
- from scipy.ndimage import uniform_filter1d
21
- from scipy.signal import hilbert
22
-
23
  logging.basicConfig(level=logging.DEBUG)
24
 
25
  from sgmse import ScoreModel
26
  from sgmse.util.other import pad_spec
27
 
28
 
29
- def gentle_noise_reduction(audio, sr):
30
- """Very gentle noise reduction - only remove obvious noise"""
31
-
32
- # Only target very quiet background noise
33
- abs_audio = np.abs(audio)
34
- noise_threshold = np.percentile(abs_audio, 5) # Bottom 5% only
35
-
36
- # Very conservative gating - only suppress very quiet parts
37
- gate_threshold = noise_threshold * 1.5 # Very low threshold
38
- mask = abs_audio > gate_threshold
39
-
40
- # Smooth the mask heavily to avoid artifacts
41
- window_size = int(0.05 * sr) # 50ms smoothing
42
- if window_size % 2 == 0:
43
- window_size += 1
44
-
45
- mask_smooth = uniform_filter1d(mask.astype(float), size=window_size)
46
- mask_smooth = np.clip(mask_smooth, 0.8, 1.0) # Never go below 80%
47
-
48
- return audio * mask_smooth
49
-
50
- def minimal_speech_boost(audio, sr):
51
- """Minimal boost to speech frequencies"""
52
-
53
- # Very light boost to mid frequencies (1-3 kHz) - critical for PESQ
54
- nyquist = sr / 2
55
- low_freq = 1000 / nyquist
56
- high_freq = 3000 / nyquist
57
-
58
- # Design a very gentle bandpass filter
59
- b, a = butter(2, [low_freq, high_freq], btype='band') # Order 2 only
60
- mid_freq_content = filtfilt(b, a, audio)
61
-
62
- # Very small boost - only 2%
63
- boost_amount = 0.02
64
- enhanced_audio = audio + boost_amount * mid_freq_content
65
-
66
- return enhanced_audio
67
-
68
-
69
- def conservative_enhancement(enhanced_file):
70
- """Very conservative enhancement - minimal processing for small improvements"""
71
-
72
- audio, sr = librosa.load(enhanced_file, sr=16000)
73
- original_audio = audio.copy()
74
- original_length = len(audio)
75
-
76
- # Step 1: Very light noise reduction (only remove obvious noise)
77
- audio_denoised = gentle_noise_reduction(audio, sr)
78
-
79
- # Step 2: Minimal speech clarity boost
80
- audio_enhanced = minimal_speech_boost(audio_denoised, sr)
81
-
82
- # Ensure same length
83
- if len(audio_enhanced) != original_length:
84
- if len(audio_enhanced) > original_length:
85
- audio_enhanced = audio_enhanced[:original_length]
86
- else:
87
- audio_enhanced = np.pad(audio_enhanced, (0, original_length - len(audio_enhanced)), mode='constant')
88
-
89
- # Very conservative blending - mostly keep original
90
- blend_ratio = 0.15 # Only 15% enhancement, 85% original
91
- audio_result = blend_ratio * audio_enhanced + (1 - blend_ratio) * original_audio
92
-
93
- return audio_result
94
-
95
  class ModelAPI:
96
-
97
  def __init__(self, host, port):
98
-
99
- self.host = host
100
  self.port = port
101
-
102
  self.base_path = os.path.join(os.path.expanduser("~"), ".modelapi")
103
  self.noisy_audio_path = os.path.join(self.base_path, "noisy_audio")
104
  self.enhanced_audio_path = os.path.join(self.base_path, "enhanced_audio")
105
  app_dir = os.path.dirname(os.path.abspath(__file__))
106
- self.ckpt_path = os.path.join(app_dir,"miner_49.ckpt")
107
- self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
108
  self.corrector = "ald"
109
  self.corrector_steps = 1
110
  self.snr = 0.5
@@ -113,10 +41,10 @@ class ModelAPI:
113
  for audio_path in [self.noisy_audio_path, self.enhanced_audio_path]:
114
  if not os.path.exists(audio_path):
115
  os.makedirs(audio_path)
116
-
117
  for filename in os.listdir(audio_path):
118
  file_path = os.path.join(audio_path, filename)
119
-
120
  try:
121
  if os.path.isfile(file_path) or os.path.islink(file_path):
122
  os.unlink(file_path)
@@ -124,29 +52,45 @@ class ModelAPI:
124
  shutil.rmtree(file_path)
125
  except Exception as e:
126
  raise e
127
-
128
  self.app = fastapi.FastAPI()
129
  self._setup_routes()
130
 
131
  def _prepare(self):
 
 
 
 
 
 
132
  self.model = ScoreModel.load_from_checkpoint(self.ckpt_path, self.device)
133
  self.model.t_eps = 0.03
134
  self.model.eval()
135
 
136
  def _enhance(self):
137
- if self.model.backbone == 'ncsnpp_48k':
 
 
 
 
 
 
 
 
 
138
  target_sr = 48000
139
  pad_mode = "reflection"
140
- elif self.model.backbone == 'ncsnpp_v2':
141
  target_sr = 16000
142
  pad_mode = "reflection"
 
143
  else:
144
  target_sr = 16000
145
  pad_mode = "zero_pad"
146
 
147
- noisy_files = sorted(glob.glob(os.path.join(self.noisy_audio_path, '*.wav')))
148
- for noisy_file in noisy_files:
149
-
150
  filename = noisy_file.replace(self.noisy_audio_path, "")
151
  filename = filename[1:] if filename.startswith("/") else filename
152
 
@@ -155,91 +99,115 @@ class ModelAPI:
155
  if sr != target_sr:
156
  y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))
157
 
158
- T_orig = y.size(1)
159
 
160
  # Normalize
161
  norm_factor = y.abs().max()
162
  y = y / norm_factor
163
-
164
  # Prepare DNN input
165
- Y = torch.unsqueeze(self.model._forward_transform(self.model._stft(y.to(self.device))), 0)
 
 
166
  Y = pad_spec(Y, mode=pad_mode)
167
-
168
  # Reverse sampling
169
- if self.model.sde.__class__.__name__ == 'OUVESDE':
170
- if self.model.sde.sampler_type == 'pc':
171
- sampler = self.model.get_pc_sampler('reverse_diffusion', self.corrector, Y.to(self.device), N=self.N,
172
- corrector_steps=self.corrector_steps, snr=self.snr)
173
- elif self.model.sde.sampler_type == 'ode':
 
 
 
 
 
 
174
  sampler = self.model.get_ode_sampler(Y.to(self.device), N=self.N)
175
  else:
176
  raise ValueError(f"Sampler type {args.sampler_type} not supported")
177
- elif self.model.sde.__class__.__name__ == 'SBVESDE':
178
- sampler_type = 'ode' if self.model.sde.sampler_type == 'pc' else self.model.sde.sampler_type
179
- sampler = self.model.get_sb_sampler(sde=self.model.sde, y=Y.cuda(), sampler_type=sampler_type)
 
 
 
 
 
 
180
  else:
181
- raise ValueError(f"SDE {self.model.sde.__class__.__name__} not supported")
182
-
 
 
183
  sample, _ = sampler()
184
-
185
  x_hat = self.model.to_audio(sample.squeeze(), T_orig)
186
 
187
  x_hat = x_hat * norm_factor
188
-
189
- os.makedirs(os.path.dirname(os.path.join(self.enhanced_audio_path, filename)), exist_ok=True)
190
- enhanced_file = os.path.join(self.enhanced_audio_path, filename)
191
- write(enhanced_file, x_hat.cpu().numpy(), target_sr)
192
 
193
- try:
194
- audio_enhanced = conservative_enhancement(enhanced_file)
195
- write(enhanced_file, audio_enhanced, target_sr)
196
- except Exception as e:
197
- write(enhanced_file, x_hat.cpu().numpy(), target_sr)
198
-
 
 
 
 
199
  def _setup_routes(self):
200
  self.app.get("/status/")(self.get_status)
201
  self.app.post("/prepare/")(self.prepare)
202
  self.app.post("/upload-audio/")(self.upload_audio)
203
  self.app.post("/enhance/")(self.enhance_audio)
204
  self.app.get("/download-enhanced/")(self.download_enhanced)
205
-
 
206
  def get_status(self):
207
  try:
208
  return {"container_running": True}
209
  except Exception as e:
210
  logging.error(f"Error getting status: {e}")
211
- raise fastapi.HTTPException(status_code=500, detail="An error occurred while fetching API status.")
212
-
 
 
213
  def prepare(self):
214
  try:
215
  self._prepare()
216
- return {'preparations': True}
217
  except Exception as e:
218
  logging.error(f"Error during preparations: {e}")
219
- return fastapi.HTTPException(status_code=500, detail="An error occurred while fetching API status.")
220
-
 
 
221
  def upload_audio(self, files: List[fastapi.UploadFile] = fastapi.File(...)):
222
-
223
  uploaded_files = []
224
-
225
  for file in files:
226
- try:
227
  file_path = os.path.join(self.noisy_audio_path, file.filename)
228
-
229
  with open(file_path, "wb") as f:
230
- while contents := file.file.read(1024*1024):
231
  f.write(contents)
232
 
233
- uploaded_files.append(file.filename)
234
-
235
  except Exception as e:
236
- logging.error(f"Error uploading files: {e}")
237
- raise fastapi.HTTPException(status_code=500, detail="An error occurred while uploading the noisy files.")
 
 
 
238
  finally:
239
  file.file.close()
240
-
241
  print(f"uploaded files: {uploaded_files}")
242
-
243
  return {"uploaded_files": uploaded_files, "status": True}
244
 
245
  def enhance_audio(self):
@@ -247,35 +215,70 @@ class ModelAPI:
247
  # Enhance audio
248
  self._enhance()
249
  # Obtain list of file paths for enhanced audio
250
- wav_files = glob.glob(os.path.join(self.enhanced_audio_path, '*.wav'))
251
  # Extract just the file names
252
  enhanced_files = [os.path.basename(file) for file in wav_files]
253
  return {"status": True}
254
-
255
  except Exception as e:
256
  print(f"Exception occured during enhancement: {e}")
257
- raise fastapi.HTTPException(status_code=500, detail="An error occurred while enhancing the noisy files.")
258
-
 
 
 
259
  def download_enhanced(self):
260
  try:
261
  zip_buffer = io.BytesIO()
262
 
263
  with zipfile.ZipFile(zip_buffer, "w") as zip_file:
264
- for wav_file in glob.glob(os.path.join(self.enhanced_audio_path, '*.wav')):
 
 
265
  zip_file.write(wav_file, arcname=os.path.basename(wav_file))
266
  zip_buffer.seek(0)
267
 
268
-
269
  return fastapi.responses.StreamingResponse(
270
  iter([zip_buffer.getvalue()]), # Stream the in-memory content
271
  media_type="application/zip",
272
- headers={"Content-Disposition": "attachment; filename=enhanced_audio_files.zip"}
 
 
273
  )
274
 
275
  except Exception as e:
276
  logging.error(f"Error during enhanced files download: {e}")
277
- raise fastapi.HTTPException(status_code=500, detail=f"An error occurred while creating the download file: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
  def run(self):
280
-
281
- uvicorn.run(self.app, host=self.host, port=self.port)
 
1
+ import fastapi
2
+ import shutil
3
+ import os
4
+ import zipfile
5
+ import io
6
+ import uvicorn
7
  import threading
8
  import glob
9
+ from typing import List
10
  import torch
11
  import gdown
12
  from soundfile import write
 
14
  from librosa import resample
15
  import logging
16
 
 
 
 
 
 
 
17
  logging.basicConfig(level=logging.DEBUG)
18
 
19
  from sgmse import ScoreModel
20
  from sgmse.util.other import pad_spec
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  class ModelAPI:
24
+
25
  def __init__(self, host, port):
26
+
27
+ self.host = host
28
  self.port = port
29
+
30
  self.base_path = os.path.join(os.path.expanduser("~"), ".modelapi")
31
  self.noisy_audio_path = os.path.join(self.base_path, "noisy_audio")
32
  self.enhanced_audio_path = os.path.join(self.base_path, "enhanced_audio")
33
  app_dir = os.path.dirname(os.path.abspath(__file__))
34
+ self.ckpt_path = glob.glob(os.path.join(app_dir, "*.ckpt"))[0]
35
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
36
  self.corrector = "ald"
37
  self.corrector_steps = 1
38
  self.snr = 0.5
 
41
  for audio_path in [self.noisy_audio_path, self.enhanced_audio_path]:
42
  if not os.path.exists(audio_path):
43
  os.makedirs(audio_path)
44
+
45
  for filename in os.listdir(audio_path):
46
  file_path = os.path.join(audio_path, filename)
47
+
48
  try:
49
  if os.path.isfile(file_path) or os.path.islink(file_path):
50
  os.unlink(file_path)
 
52
  shutil.rmtree(file_path)
53
  except Exception as e:
54
  raise e
55
+
56
  self.app = fastapi.FastAPI()
57
  self._setup_routes()
58
 
59
  def _prepare(self):
60
+ """Miners should modify this function to fit their fine-tuned models.
61
+
62
+ This function will make any preparations necessary to initialize the
63
+ speech enhancement model (i.e. downloading checkpoint files, etc.)
64
+ """
65
+
66
  self.model = ScoreModel.load_from_checkpoint(self.ckpt_path, self.device)
67
  self.model.t_eps = 0.03
68
  self.model.eval()
69
 
70
  def _enhance(self):
71
+ """
72
+ Miners should modify this function to fit their fine-tuned models.
73
+
74
+ This function will:
75
+ 1. Open each noisy .wav file
76
+ 2. Enhance the audio with the model
77
+ 3. Save the enhanced audio in .wav format to ModelAPI.enhanced_audio_path
78
+ """
79
+
80
+ if self.model.backbone == "ncsnpp_48k":
81
  target_sr = 48000
82
  pad_mode = "reflection"
83
+ elif self.model.backbone == "ncsnpp_v2":
84
  target_sr = 16000
85
  pad_mode = "reflection"
86
+ print("using ncsnpp_v2")
87
  else:
88
  target_sr = 16000
89
  pad_mode = "zero_pad"
90
 
91
+ noisy_files = sorted(glob.glob(os.path.join(self.noisy_audio_path, "*.wav")))
92
+ for noisy_file in noisy_files:
93
+
94
  filename = noisy_file.replace(self.noisy_audio_path, "")
95
  filename = filename[1:] if filename.startswith("/") else filename
96
 
 
99
  if sr != target_sr:
100
  y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))
101
 
102
+ T_orig = y.size(1)
103
 
104
  # Normalize
105
  norm_factor = y.abs().max()
106
  y = y / norm_factor
107
+
108
  # Prepare DNN input
109
+ Y = torch.unsqueeze(
110
+ self.model._forward_transform(self.model._stft(y.to(self.device))), 0
111
+ )
112
  Y = pad_spec(Y, mode=pad_mode)
113
+
114
  # Reverse sampling
115
+ if self.model.sde.__class__.__name__ == "OUVESDE":
116
+ if self.model.sde.sampler_type == "pc":
117
+ sampler = self.model.get_pc_sampler(
118
+ "reverse_diffusion",
119
+ self.corrector,
120
+ Y.to(self.device),
121
+ N=self.N,
122
+ corrector_steps=self.corrector_steps,
123
+ snr=self.snr,
124
+ )
125
+ elif self.model.sde.sampler_type == "ode":
126
  sampler = self.model.get_ode_sampler(Y.to(self.device), N=self.N)
127
  else:
128
  raise ValueError(f"Sampler type {args.sampler_type} not supported")
129
+ elif self.model.sde.__class__.__name__ == "SBVESDE":
130
+ sampler_type = (
131
+ "ode"
132
+ if self.model.sde.sampler_type == "pc"
133
+ else self.model.sde.sampler_type
134
+ )
135
+ sampler = self.model.get_sb_sampler(
136
+ sde=self.model.sde, y=Y.cuda(), sampler_type=sampler_type
137
+ )
138
  else:
139
+ raise ValueError(
140
+ f"SDE {self.model.sde.__class__.__name__} not supported"
141
+ )
142
+
143
  sample, _ = sampler()
144
+
145
  x_hat = self.model.to_audio(sample.squeeze(), T_orig)
146
 
147
  x_hat = x_hat * norm_factor
 
 
 
 
148
 
149
+ os.makedirs(
150
+ os.path.dirname(os.path.join(self.enhanced_audio_path, filename)),
151
+ exist_ok=True,
152
+ )
153
+ write(
154
+ os.path.join(self.enhanced_audio_path, filename),
155
+ x_hat.cpu().numpy(),
156
+ target_sr,
157
+ )
158
+
159
  def _setup_routes(self):
160
  self.app.get("/status/")(self.get_status)
161
  self.app.post("/prepare/")(self.prepare)
162
  self.app.post("/upload-audio/")(self.upload_audio)
163
  self.app.post("/enhance/")(self.enhance_audio)
164
  self.app.get("/download-enhanced/")(self.download_enhanced)
165
+ self.app.post("/reset/")(self.reset)
166
+
167
  def get_status(self):
168
  try:
169
  return {"container_running": True}
170
  except Exception as e:
171
  logging.error(f"Error getting status: {e}")
172
+ raise fastapi.HTTPException(
173
+ status_code=500, detail="An error occurred while fetching API status."
174
+ )
175
+
176
  def prepare(self):
177
  try:
178
  self._prepare()
179
+ return {"preparations": True}
180
  except Exception as e:
181
  logging.error(f"Error during preparations: {e}")
182
+ return fastapi.HTTPException(
183
+ status_code=500, detail="An error occurred while fetching API status."
184
+ )
185
+
186
  def upload_audio(self, files: List[fastapi.UploadFile] = fastapi.File(...)):
187
+
188
  uploaded_files = []
189
+
190
  for file in files:
191
+ try:
192
  file_path = os.path.join(self.noisy_audio_path, file.filename)
193
+
194
  with open(file_path, "wb") as f:
195
+ while contents := file.file.read(1024 * 1024):
196
  f.write(contents)
197
 
198
+ uploaded_files.append(file.filename)
199
+
200
  except Exception as e:
201
+ logging.error(f"Error uploading files: {e}")
202
+ raise fastapi.HTTPException(
203
+ status_code=500,
204
+ detail="An error occurred while uploading the noisy files.",
205
+ )
206
  finally:
207
  file.file.close()
208
+
209
  print(f"uploaded files: {uploaded_files}")
210
+
211
  return {"uploaded_files": uploaded_files, "status": True}
212
 
213
  def enhance_audio(self):
 
215
  # Enhance audio
216
  self._enhance()
217
  # Obtain list of file paths for enhanced audio
218
+ wav_files = glob.glob(os.path.join(self.enhanced_audio_path, "*.wav"))
219
  # Extract just the file names
220
  enhanced_files = [os.path.basename(file) for file in wav_files]
221
  return {"status": True}
222
+
223
  except Exception as e:
224
  print(f"Exception occured during enhancement: {e}")
225
+ raise fastapi.HTTPException(
226
+ status_code=500,
227
+ detail="An error occurred while enhancing the noisy files.",
228
+ )
229
+
230
  def download_enhanced(self):
231
  try:
232
  zip_buffer = io.BytesIO()
233
 
234
  with zipfile.ZipFile(zip_buffer, "w") as zip_file:
235
+ for wav_file in glob.glob(
236
+ os.path.join(self.enhanced_audio_path, "*.wav")
237
+ ):
238
  zip_file.write(wav_file, arcname=os.path.basename(wav_file))
239
  zip_buffer.seek(0)
240
 
 
241
  return fastapi.responses.StreamingResponse(
242
  iter([zip_buffer.getvalue()]), # Stream the in-memory content
243
  media_type="application/zip",
244
+ headers={
245
+ "Content-Disposition": "attachment; filename=enhanced_audio_files.zip"
246
+ },
247
  )
248
 
249
  except Exception as e:
250
  logging.error(f"Error during enhanced files download: {e}")
251
+ raise fastapi.HTTPException(
252
+ status_code=500,
253
+ detail=f"An error occurred while creating the download file: {str(e)}",
254
+ )
255
+
256
+ def reset(self):
257
+ """
258
+ Removes all audio files in preparation for another batch of enhancement.
259
+ """
260
+ for directory in [self.noisy_audio_path, self.enhanced_audio_path]:
261
+ if not os.path.isdir(directory):
262
+ continue
263
+
264
+ for filename in os.listdir(directory):
265
+ filepath = os.path.join(directory, filename)
266
+ if os.path.isfile(filepath):
267
+ try:
268
+ os.remove(filepath)
269
+ except Exception as e:
270
+ print(f"Error removing {filepath}: {e}")
271
+ return {
272
+ "status": False,
273
+ "noisy": os.listdir(self.noisy_audio_path),
274
+ "enhanced": os.listdir(self.enhanced_audio_path),
275
+ }
276
+ return {
277
+ "status": True,
278
+ "noisy": os.listdir(self.noisy_audio_path),
279
+ "enhanced": os.listdir(self.enhanced_audio_path),
280
+ }
281
 
282
  def run(self):
283
+
284
+ uvicorn.run(self.app, host=self.host, port=self.port)