vera6 commited on
Commit
60fbe39
·
verified ·
1 Parent(s): d0e7aef

Update app/app.py

Browse files
Files changed (1) hide show
  1. app/app.py +151 -67
app/app.py CHANGED
@@ -1,36 +1,38 @@
1
- import fastapi
2
- import shutil
3
- import os
4
- import zipfile
5
- import io
6
- import uvicorn
7
  import threading
8
  import glob
9
- from typing import List
10
  import torch
11
  import gdown
12
  from soundfile import write
13
  from torchaudio import load
14
  from librosa import resample
15
  import logging
 
16
  logging.basicConfig(level=logging.DEBUG)
17
 
18
  from sgmse import ScoreModel
19
  from sgmse.util.other import pad_spec
20
 
 
21
  class ModelAPI:
22
-
23
  def __init__(self, host, port):
24
-
25
- self.host = host
26
  self.port = port
27
-
28
  self.base_path = os.path.join(os.path.expanduser("~"), ".modelapi")
29
  self.noisy_audio_path = os.path.join(self.base_path, "noisy_audio")
30
  self.enhanced_audio_path = os.path.join(self.base_path, "enhanced_audio")
31
  app_dir = os.path.dirname(os.path.abspath(__file__))
32
- self.ckpt_path = os.path.join(app_dir,"miner_33.ckpt")
33
- self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
34
  self.corrector = "ald"
35
  self.corrector_steps = 1
36
  self.snr = 0.5
@@ -39,10 +41,10 @@ class ModelAPI:
39
  for audio_path in [self.noisy_audio_path, self.enhanced_audio_path]:
40
  if not os.path.exists(audio_path):
41
  os.makedirs(audio_path)
42
-
43
  for filename in os.listdir(audio_path):
44
  file_path = os.path.join(audio_path, filename)
45
-
46
  try:
47
  if os.path.isfile(file_path) or os.path.islink(file_path):
48
  os.unlink(file_path)
@@ -50,29 +52,45 @@ class ModelAPI:
50
  shutil.rmtree(file_path)
51
  except Exception as e:
52
  raise e
53
-
54
  self.app = fastapi.FastAPI()
55
  self._setup_routes()
56
 
57
  def _prepare(self):
 
 
 
 
 
 
58
  self.model = ScoreModel.load_from_checkpoint(self.ckpt_path, self.device)
59
  self.model.t_eps = 0.03
60
  self.model.eval()
61
 
62
  def _enhance(self):
63
- if self.model.backbone == 'ncsnpp_48k':
 
 
 
 
 
 
 
 
 
64
  target_sr = 48000
65
  pad_mode = "reflection"
66
- elif self.model.backbone == 'ncsnpp_v2':
67
  target_sr = 16000
68
  pad_mode = "reflection"
 
69
  else:
70
  target_sr = 16000
71
  pad_mode = "zero_pad"
72
 
73
- noisy_files = sorted(glob.glob(os.path.join(self.noisy_audio_path, '*.wav')))
74
- for noisy_file in noisy_files:
75
-
76
  filename = noisy_file.replace(self.noisy_audio_path, "")
77
  filename = filename[1:] if filename.startswith("/") else filename
78
 
@@ -81,84 +99,115 @@ class ModelAPI:
81
  if sr != target_sr:
82
  y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))
83
 
84
- T_orig = y.size(1)
85
 
86
  # Normalize
87
  norm_factor = y.abs().max()
88
  y = y / norm_factor
89
-
90
  # Prepare DNN input
91
- Y = torch.unsqueeze(self.model._forward_transform(self.model._stft(y.to(self.device))), 0)
 
 
92
  Y = pad_spec(Y, mode=pad_mode)
93
-
94
  # Reverse sampling
95
- if self.model.sde.__class__.__name__ == 'OUVESDE':
96
- if self.model.sde.sampler_type == 'pc':
97
- sampler = self.model.get_pc_sampler('reverse_diffusion', self.corrector, Y.to(self.device), N=self.N,
98
- corrector_steps=self.corrector_steps, snr=self.snr)
99
- elif self.model.sde.sampler_type == 'ode':
 
 
 
 
 
 
100
  sampler = self.model.get_ode_sampler(Y.to(self.device), N=self.N)
101
  else:
102
  raise ValueError(f"Sampler type {args.sampler_type} not supported")
103
- elif self.model.sde.__class__.__name__ == 'SBVESDE':
104
- sampler_type = 'ode' if self.model.sde.sampler_type == 'pc' else self.model.sde.sampler_type
105
- sampler = self.model.get_sb_sampler(sde=self.model.sde, y=Y.cuda(), sampler_type=sampler_type)
 
 
 
 
 
 
106
  else:
107
- raise ValueError(f"SDE {self.model.sde.__class__.__name__} not supported")
108
-
 
 
109
  sample, _ = sampler()
110
-
111
  x_hat = self.model.to_audio(sample.squeeze(), T_orig)
112
 
113
  x_hat = x_hat * norm_factor
114
-
115
- os.makedirs(os.path.dirname(os.path.join(self.enhanced_audio_path, filename)), exist_ok=True)
116
- write(os.path.join(self.enhanced_audio_path, filename), x_hat.cpu().numpy(), target_sr)
117
-
 
 
 
 
 
 
 
118
  def _setup_routes(self):
119
  self.app.get("/status/")(self.get_status)
120
  self.app.post("/prepare/")(self.prepare)
121
  self.app.post("/upload-audio/")(self.upload_audio)
122
  self.app.post("/enhance/")(self.enhance_audio)
123
  self.app.get("/download-enhanced/")(self.download_enhanced)
124
-
 
125
  def get_status(self):
126
  try:
127
  return {"container_running": True}
128
  except Exception as e:
129
  logging.error(f"Error getting status: {e}")
130
- raise fastapi.HTTPException(status_code=500, detail="An error occurred while fetching API status.")
131
-
 
 
132
  def prepare(self):
133
  try:
134
  self._prepare()
135
- return {'preparations': True}
136
  except Exception as e:
137
  logging.error(f"Error during preparations: {e}")
138
- return fastapi.HTTPException(status_code=500, detail="An error occurred while fetching API status.")
139
-
 
 
140
  def upload_audio(self, files: List[fastapi.UploadFile] = fastapi.File(...)):
141
-
142
  uploaded_files = []
143
-
144
  for file in files:
145
- try:
146
  file_path = os.path.join(self.noisy_audio_path, file.filename)
147
-
148
  with open(file_path, "wb") as f:
149
- while contents := file.file.read(1024*1024):
150
  f.write(contents)
151
 
152
- uploaded_files.append(file.filename)
153
-
154
  except Exception as e:
155
- logging.error(f"Error uploading files: {e}")
156
- raise fastapi.HTTPException(status_code=500, detail="An error occurred while uploading the noisy files.")
 
 
 
157
  finally:
158
  file.file.close()
159
-
160
  print(f"uploaded files: {uploaded_files}")
161
-
162
  return {"uploaded_files": uploaded_files, "status": True}
163
 
164
  def enhance_audio(self):
@@ -166,35 +215,70 @@ class ModelAPI:
166
  # Enhance audio
167
  self._enhance()
168
  # Obtain list of file paths for enhanced audio
169
- wav_files = glob.glob(os.path.join(self.enhanced_audio_path, '*.wav'))
170
  # Extract just the file names
171
  enhanced_files = [os.path.basename(file) for file in wav_files]
172
  return {"status": True}
173
-
174
  except Exception as e:
175
  print(f"Exception occured during enhancement: {e}")
176
- raise fastapi.HTTPException(status_code=500, detail="An error occurred while enhancing the noisy files.")
177
-
 
 
 
178
  def download_enhanced(self):
179
  try:
180
  zip_buffer = io.BytesIO()
181
 
182
  with zipfile.ZipFile(zip_buffer, "w") as zip_file:
183
- for wav_file in glob.glob(os.path.join(self.enhanced_audio_path, '*.wav')):
 
 
184
  zip_file.write(wav_file, arcname=os.path.basename(wav_file))
185
  zip_buffer.seek(0)
186
 
187
-
188
  return fastapi.responses.StreamingResponse(
189
  iter([zip_buffer.getvalue()]), # Stream the in-memory content
190
  media_type="application/zip",
191
- headers={"Content-Disposition": "attachment; filename=enhanced_audio_files.zip"}
 
 
192
  )
193
 
194
  except Exception as e:
195
  logging.error(f"Error during enhanced files download: {e}")
196
- raise fastapi.HTTPException(status_code=500, detail=f"An error occurred while creating the download file: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  def run(self):
199
-
200
- uvicorn.run(self.app, host=self.host, port=self.port)
 
1
+ import fastapi
2
+ import shutil
3
+ import os
4
+ import zipfile
5
+ import io
6
+ import uvicorn
7
  import threading
8
  import glob
9
+ from typing import List
10
  import torch
11
  import gdown
12
  from soundfile import write
13
  from torchaudio import load
14
  from librosa import resample
15
  import logging
16
+
17
  logging.basicConfig(level=logging.DEBUG)
18
 
19
  from sgmse import ScoreModel
20
  from sgmse.util.other import pad_spec
21
 
22
+
23
  class ModelAPI:
24
+
25
  def __init__(self, host, port):
26
+
27
+ self.host = host
28
  self.port = port
29
+
30
  self.base_path = os.path.join(os.path.expanduser("~"), ".modelapi")
31
  self.noisy_audio_path = os.path.join(self.base_path, "noisy_audio")
32
  self.enhanced_audio_path = os.path.join(self.base_path, "enhanced_audio")
33
  app_dir = os.path.dirname(os.path.abspath(__file__))
34
+ self.ckpt_path = glob.glob(os.path.join(app_dir, "*.ckpt"))[0]
35
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
36
  self.corrector = "ald"
37
  self.corrector_steps = 1
38
  self.snr = 0.5
 
41
  for audio_path in [self.noisy_audio_path, self.enhanced_audio_path]:
42
  if not os.path.exists(audio_path):
43
  os.makedirs(audio_path)
44
+
45
  for filename in os.listdir(audio_path):
46
  file_path = os.path.join(audio_path, filename)
47
+
48
  try:
49
  if os.path.isfile(file_path) or os.path.islink(file_path):
50
  os.unlink(file_path)
 
52
  shutil.rmtree(file_path)
53
  except Exception as e:
54
  raise e
55
+
56
  self.app = fastapi.FastAPI()
57
  self._setup_routes()
58
 
59
  def _prepare(self):
60
+ """Miners should modify this function to fit their fine-tuned models.
61
+
62
+ This function will make any preparations necessary to initialize the
63
+ speech enhancement model (i.e. downloading checkpoint files, etc.)
64
+ """
65
+
66
  self.model = ScoreModel.load_from_checkpoint(self.ckpt_path, self.device)
67
  self.model.t_eps = 0.03
68
  self.model.eval()
69
 
70
  def _enhance(self):
71
+ """
72
+ Miners should modify this function to fit their fine-tuned models.
73
+
74
+ This function will:
75
+ 1. Open each noisy .wav file
76
+ 2. Enhance the audio with the model
77
+ 3. Save the enhanced audio in .wav format to ModelAPI.enhanced_audio_path
78
+ """
79
+
80
+ if self.model.backbone == "ncsnpp_48k":
81
  target_sr = 48000
82
  pad_mode = "reflection"
83
+ elif self.model.backbone == "ncsnpp_v2":
84
  target_sr = 16000
85
  pad_mode = "reflection"
86
+ print("using ncsnpp_v2")
87
  else:
88
  target_sr = 16000
89
  pad_mode = "zero_pad"
90
 
91
+ noisy_files = sorted(glob.glob(os.path.join(self.noisy_audio_path, "*.wav")))
92
+ for noisy_file in noisy_files:
93
+
94
  filename = noisy_file.replace(self.noisy_audio_path, "")
95
  filename = filename[1:] if filename.startswith("/") else filename
96
 
 
99
  if sr != target_sr:
100
  y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))
101
 
102
+ T_orig = y.size(1)
103
 
104
  # Normalize
105
  norm_factor = y.abs().max()
106
  y = y / norm_factor
107
+
108
  # Prepare DNN input
109
+ Y = torch.unsqueeze(
110
+ self.model._forward_transform(self.model._stft(y.to(self.device))), 0
111
+ )
112
  Y = pad_spec(Y, mode=pad_mode)
113
+
114
  # Reverse sampling
115
+ if self.model.sde.__class__.__name__ == "OUVESDE":
116
+ if self.model.sde.sampler_type == "pc":
117
+ sampler = self.model.get_pc_sampler(
118
+ "reverse_diffusion",
119
+ self.corrector,
120
+ Y.to(self.device),
121
+ N=self.N,
122
+ corrector_steps=self.corrector_steps,
123
+ snr=self.snr,
124
+ )
125
+ elif self.model.sde.sampler_type == "ode":
126
  sampler = self.model.get_ode_sampler(Y.to(self.device), N=self.N)
127
  else:
128
  raise ValueError(f"Sampler type {args.sampler_type} not supported")
129
+ elif self.model.sde.__class__.__name__ == "SBVESDE":
130
+ sampler_type = (
131
+ "ode"
132
+ if self.model.sde.sampler_type == "pc"
133
+ else self.model.sde.sampler_type
134
+ )
135
+ sampler = self.model.get_sb_sampler(
136
+ sde=self.model.sde, y=Y.cuda(), sampler_type=sampler_type
137
+ )
138
  else:
139
+ raise ValueError(
140
+ f"SDE {self.model.sde.__class__.__name__} not supported"
141
+ )
142
+
143
  sample, _ = sampler()
144
+
145
  x_hat = self.model.to_audio(sample.squeeze(), T_orig)
146
 
147
  x_hat = x_hat * norm_factor
148
+
149
+ os.makedirs(
150
+ os.path.dirname(os.path.join(self.enhanced_audio_path, filename)),
151
+ exist_ok=True,
152
+ )
153
+ write(
154
+ os.path.join(self.enhanced_audio_path, filename),
155
+ x_hat.cpu().numpy(),
156
+ target_sr,
157
+ )
158
+
159
  def _setup_routes(self):
160
  self.app.get("/status/")(self.get_status)
161
  self.app.post("/prepare/")(self.prepare)
162
  self.app.post("/upload-audio/")(self.upload_audio)
163
  self.app.post("/enhance/")(self.enhance_audio)
164
  self.app.get("/download-enhanced/")(self.download_enhanced)
165
+ self.app.post("/reset/")(self.reset)
166
+
167
  def get_status(self):
168
  try:
169
  return {"container_running": True}
170
  except Exception as e:
171
  logging.error(f"Error getting status: {e}")
172
+ raise fastapi.HTTPException(
173
+ status_code=500, detail="An error occurred while fetching API status."
174
+ )
175
+
176
  def prepare(self):
177
  try:
178
  self._prepare()
179
+ return {"preparations": True}
180
  except Exception as e:
181
  logging.error(f"Error during preparations: {e}")
182
+ return fastapi.HTTPException(
183
+ status_code=500, detail="An error occurred while fetching API status."
184
+ )
185
+
186
  def upload_audio(self, files: List[fastapi.UploadFile] = fastapi.File(...)):
187
+
188
  uploaded_files = []
189
+
190
  for file in files:
191
+ try:
192
  file_path = os.path.join(self.noisy_audio_path, file.filename)
193
+
194
  with open(file_path, "wb") as f:
195
+ while contents := file.file.read(1024 * 1024):
196
  f.write(contents)
197
 
198
+ uploaded_files.append(file.filename)
199
+
200
  except Exception as e:
201
+ logging.error(f"Error uploading files: {e}")
202
+ raise fastapi.HTTPException(
203
+ status_code=500,
204
+ detail="An error occurred while uploading the noisy files.",
205
+ )
206
  finally:
207
  file.file.close()
208
+
209
  print(f"uploaded files: {uploaded_files}")
210
+
211
  return {"uploaded_files": uploaded_files, "status": True}
212
 
213
  def enhance_audio(self):
 
215
  # Enhance audio
216
  self._enhance()
217
  # Obtain list of file paths for enhanced audio
218
+ wav_files = glob.glob(os.path.join(self.enhanced_audio_path, "*.wav"))
219
  # Extract just the file names
220
  enhanced_files = [os.path.basename(file) for file in wav_files]
221
  return {"status": True}
222
+
223
  except Exception as e:
224
  print(f"Exception occured during enhancement: {e}")
225
+ raise fastapi.HTTPException(
226
+ status_code=500,
227
+ detail="An error occurred while enhancing the noisy files.",
228
+ )
229
+
230
  def download_enhanced(self):
231
  try:
232
  zip_buffer = io.BytesIO()
233
 
234
  with zipfile.ZipFile(zip_buffer, "w") as zip_file:
235
+ for wav_file in glob.glob(
236
+ os.path.join(self.enhanced_audio_path, "*.wav")
237
+ ):
238
  zip_file.write(wav_file, arcname=os.path.basename(wav_file))
239
  zip_buffer.seek(0)
240
 
 
241
  return fastapi.responses.StreamingResponse(
242
  iter([zip_buffer.getvalue()]), # Stream the in-memory content
243
  media_type="application/zip",
244
+ headers={
245
+ "Content-Disposition": "attachment; filename=enhanced_audio_files.zip"
246
+ },
247
  )
248
 
249
  except Exception as e:
250
  logging.error(f"Error during enhanced files download: {e}")
251
+ raise fastapi.HTTPException(
252
+ status_code=500,
253
+ detail=f"An error occurred while creating the download file: {str(e)}",
254
+ )
255
+
256
+ def reset(self):
257
+ """
258
+ Removes all audio files in preparation for another batch of enhancement.
259
+ """
260
+ for directory in [self.noisy_audio_path, self.enhanced_audio_path]:
261
+ if not os.path.isdir(directory):
262
+ continue
263
+
264
+ for filename in os.listdir(directory):
265
+ filepath = os.path.join(directory, filename)
266
+ if os.path.isfile(filepath):
267
+ try:
268
+ os.remove(filepath)
269
+ except Exception as e:
270
+ print(f"Error removing {filepath}: {e}")
271
+ return {
272
+ "status": False,
273
+ "noisy": os.listdir(self.noisy_audio_path),
274
+ "enhanced": os.listdir(self.enhanced_audio_path),
275
+ }
276
+ return {
277
+ "status": True,
278
+ "noisy": os.listdir(self.noisy_audio_path),
279
+ "enhanced": os.listdir(self.enhanced_audio_path),
280
+ }
281
 
282
  def run(self):
283
+
284
+ uvicorn.run(self.app, host=self.host, port=self.port)