vera6 commited on
Commit
7fb65fd
·
verified ·
1 Parent(s): c5de362

Update app/app.py

Browse files
Files changed (1) hide show
  1. app/app.py +284 -200
app/app.py CHANGED
@@ -1,200 +1,284 @@
1
- import fastapi
2
- import shutil
3
- import os
4
- import zipfile
5
- import io
6
- import uvicorn
7
- import threading
8
- import glob
9
- from typing import List
10
- import torch
11
- import gdown
12
- from soundfile import write
13
- from torchaudio import load
14
- from librosa import resample
15
- import logging
16
- logging.basicConfig(level=logging.DEBUG)
17
-
18
- from sgmse import ScoreModel
19
- from sgmse.util.other import pad_spec
20
-
21
- class ModelAPI:
22
-
23
- def __init__(self, host, port):
24
-
25
- self.host = host
26
- self.port = port
27
-
28
- self.base_path = os.path.join(os.path.expanduser("~"), ".modelapi")
29
- self.noisy_audio_path = os.path.join(self.base_path, "noisy_audio")
30
- self.enhanced_audio_path = os.path.join(self.base_path, "enhanced_audio")
31
- app_dir = os.path.dirname(os.path.abspath(__file__))
32
- self.ckpt_path = os.path.join(app_dir,"miner_21.ckpt")
33
- self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
34
- self.corrector = "ald"
35
- self.corrector_steps = 1
36
- self.snr = 0.5
37
- self.N = 30
38
-
39
- for audio_path in [self.noisy_audio_path, self.enhanced_audio_path]:
40
- if not os.path.exists(audio_path):
41
- os.makedirs(audio_path)
42
-
43
- for filename in os.listdir(audio_path):
44
- file_path = os.path.join(audio_path, filename)
45
-
46
- try:
47
- if os.path.isfile(file_path) or os.path.islink(file_path):
48
- os.unlink(file_path)
49
- elif os.path.isdir(file_path):
50
- shutil.rmtree(file_path)
51
- except Exception as e:
52
- raise e
53
-
54
- self.app = fastapi.FastAPI()
55
- self._setup_routes()
56
-
57
- def _prepare(self):
58
- self.model = ScoreModel.load_from_checkpoint(self.ckpt_path, self.device)
59
- self.model.t_eps = 0.03
60
- self.model.eval()
61
-
62
- def _enhance(self):
63
- if self.model.backbone == 'ncsnpp_48k':
64
- target_sr = 48000
65
- pad_mode = "reflection"
66
- elif self.model.backbone == 'ncsnpp_v2':
67
- target_sr = 16000
68
- pad_mode = "reflection"
69
- else:
70
- target_sr = 16000
71
- pad_mode = "zero_pad"
72
-
73
- noisy_files = sorted(glob.glob(os.path.join(self.noisy_audio_path, '*.wav')))
74
- for noisy_file in noisy_files:
75
-
76
- filename = noisy_file.replace(self.noisy_audio_path, "")
77
- filename = filename[1:] if filename.startswith("/") else filename
78
-
79
- y, sr = load(noisy_file)
80
-
81
- if sr != target_sr:
82
- y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))
83
-
84
- T_orig = y.size(1)
85
-
86
- # Normalize
87
- norm_factor = y.abs().max()
88
- y = y / norm_factor
89
-
90
- # Prepare DNN input
91
- Y = torch.unsqueeze(self.model._forward_transform(self.model._stft(y.to(self.device))), 0)
92
- Y = pad_spec(Y, mode=pad_mode)
93
-
94
- # Reverse sampling
95
- if self.model.sde.__class__.__name__ == 'OUVESDE':
96
- if self.model.sde.sampler_type == 'pc':
97
- sampler = self.model.get_pc_sampler('reverse_diffusion', self.corrector, Y.to(self.device), N=self.N,
98
- corrector_steps=self.corrector_steps, snr=self.snr)
99
- elif self.model.sde.sampler_type == 'ode':
100
- sampler = self.model.get_ode_sampler(Y.to(self.device), N=self.N)
101
- else:
102
- raise ValueError(f"Sampler type {args.sampler_type} not supported")
103
- elif self.model.sde.__class__.__name__ == 'SBVESDE':
104
- sampler_type = 'ode' if self.model.sde.sampler_type == 'pc' else self.model.sde.sampler_type
105
- sampler = self.model.get_sb_sampler(sde=self.model.sde, y=Y.cuda(), sampler_type=sampler_type)
106
- else:
107
- raise ValueError(f"SDE {self.model.sde.__class__.__name__} not supported")
108
-
109
- sample, _ = sampler()
110
-
111
- x_hat = self.model.to_audio(sample.squeeze(), T_orig)
112
-
113
- x_hat = x_hat * norm_factor
114
-
115
- os.makedirs(os.path.dirname(os.path.join(self.enhanced_audio_path, filename)), exist_ok=True)
116
- write(os.path.join(self.enhanced_audio_path, filename), x_hat.cpu().numpy(), target_sr)
117
-
118
- def _setup_routes(self):
119
- self.app.get("/status/")(self.get_status)
120
- self.app.post("/prepare/")(self.prepare)
121
- self.app.post("/upload-audio/")(self.upload_audio)
122
- self.app.post("/enhance/")(self.enhance_audio)
123
- self.app.get("/download-enhanced/")(self.download_enhanced)
124
-
125
- def get_status(self):
126
- try:
127
- return {"container_running": True}
128
- except Exception as e:
129
- logging.error(f"Error getting status: {e}")
130
- raise fastapi.HTTPException(status_code=500, detail="An error occurred while fetching API status.")
131
-
132
- def prepare(self):
133
- try:
134
- self._prepare()
135
- return {'preparations': True}
136
- except Exception as e:
137
- logging.error(f"Error during preparations: {e}")
138
- return fastapi.HTTPException(status_code=500, detail="An error occurred while fetching API status.")
139
-
140
- def upload_audio(self, files: List[fastapi.UploadFile] = fastapi.File(...)):
141
-
142
- uploaded_files = []
143
-
144
- for file in files:
145
- try:
146
- file_path = os.path.join(self.noisy_audio_path, file.filename)
147
-
148
- with open(file_path, "wb") as f:
149
- while contents := file.file.read(1024*1024):
150
- f.write(contents)
151
-
152
- uploaded_files.append(file.filename)
153
-
154
- except Exception as e:
155
- logging.error(f"Error uploading files: {e}")
156
- raise fastapi.HTTPException(status_code=500, detail="An error occurred while uploading the noisy files.")
157
- finally:
158
- file.file.close()
159
-
160
- print(f"uploaded files: {uploaded_files}")
161
-
162
- return {"uploaded_files": uploaded_files, "status": True}
163
-
164
- def enhance_audio(self):
165
- try:
166
- # Enhance audio
167
- self._enhance()
168
- # Obtain list of file paths for enhanced audio
169
- wav_files = glob.glob(os.path.join(self.enhanced_audio_path, '*.wav'))
170
- # Extract just the file names
171
- enhanced_files = [os.path.basename(file) for file in wav_files]
172
- return {"status": True}
173
-
174
- except Exception as e:
175
- print(f"Exception occured during enhancement: {e}")
176
- raise fastapi.HTTPException(status_code=500, detail="An error occurred while enhancing the noisy files.")
177
-
178
- def download_enhanced(self):
179
- try:
180
- zip_buffer = io.BytesIO()
181
-
182
- with zipfile.ZipFile(zip_buffer, "w") as zip_file:
183
- for wav_file in glob.glob(os.path.join(self.enhanced_audio_path, '*.wav')):
184
- zip_file.write(wav_file, arcname=os.path.basename(wav_file))
185
- zip_buffer.seek(0)
186
-
187
-
188
- return fastapi.responses.StreamingResponse(
189
- iter([zip_buffer.getvalue()]), # Stream the in-memory content
190
- media_type="application/zip",
191
- headers={"Content-Disposition": "attachment; filename=enhanced_audio_files.zip"}
192
- )
193
-
194
- except Exception as e:
195
- logging.error(f"Error during enhanced files download: {e}")
196
- raise fastapi.HTTPException(status_code=500, detail=f"An error occurred while creating the download file: {str(e)}")
197
-
198
- def run(self):
199
-
200
- uvicorn.run(self.app, host=self.host, port=self.port)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fastapi
2
+ import shutil
3
+ import os
4
+ import zipfile
5
+ import io
6
+ import uvicorn
7
+ import threading
8
+ import glob
9
+ from typing import List
10
+ import torch
11
+ import gdown
12
+ from soundfile import write
13
+ from torchaudio import load
14
+ from librosa import resample
15
+ import logging
16
+
17
+ logging.basicConfig(level=logging.DEBUG)
18
+
19
+ from sgmse import ScoreModel
20
+ from sgmse.util.other import pad_spec
21
+
22
+
23
+ class ModelAPI:
24
+
25
+ def __init__(self, host, port):
26
+
27
+ self.host = host
28
+ self.port = port
29
+
30
+ self.base_path = os.path.join(os.path.expanduser("~"), ".modelapi")
31
+ self.noisy_audio_path = os.path.join(self.base_path, "noisy_audio")
32
+ self.enhanced_audio_path = os.path.join(self.base_path, "enhanced_audio")
33
+ app_dir = os.path.dirname(os.path.abspath(__file__))
34
+ self.ckpt_path = glob.glob(os.path.join(app_dir, "*.ckpt"))[0]
35
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
36
+ self.corrector = "ald"
37
+ self.corrector_steps = 1
38
+ self.snr = 0.5
39
+ self.N = 30
40
+
41
+ for audio_path in [self.noisy_audio_path, self.enhanced_audio_path]:
42
+ if not os.path.exists(audio_path):
43
+ os.makedirs(audio_path)
44
+
45
+ for filename in os.listdir(audio_path):
46
+ file_path = os.path.join(audio_path, filename)
47
+
48
+ try:
49
+ if os.path.isfile(file_path) or os.path.islink(file_path):
50
+ os.unlink(file_path)
51
+ elif os.path.isdir(file_path):
52
+ shutil.rmtree(file_path)
53
+ except Exception as e:
54
+ raise e
55
+
56
+ self.app = fastapi.FastAPI()
57
+ self._setup_routes()
58
+
59
+ def _prepare(self):
60
+ """Miners should modify this function to fit their fine-tuned models.
61
+
62
+ This function will make any preparations necessary to initialize the
63
+ speech enhancement model (i.e. downloading checkpoint files, etc.)
64
+ """
65
+
66
+ self.model = ScoreModel.load_from_checkpoint(self.ckpt_path, self.device)
67
+ self.model.t_eps = 0.03
68
+ self.model.eval()
69
+
70
+ def _enhance(self):
71
+ """
72
+ Miners should modify this function to fit their fine-tuned models.
73
+
74
+ This function will:
75
+ 1. Open each noisy .wav file
76
+ 2. Enhance the audio with the model
77
+ 3. Save the enhanced audio in .wav format to ModelAPI.enhanced_audio_path
78
+ """
79
+
80
+ if self.model.backbone == "ncsnpp_48k":
81
+ target_sr = 48000
82
+ pad_mode = "reflection"
83
+ elif self.model.backbone == "ncsnpp_v2":
84
+ target_sr = 16000
85
+ pad_mode = "reflection"
86
+ print("using ncsnpp_v2")
87
+ else:
88
+ target_sr = 16000
89
+ pad_mode = "zero_pad"
90
+
91
+ noisy_files = sorted(glob.glob(os.path.join(self.noisy_audio_path, "*.wav")))
92
+ for noisy_file in noisy_files:
93
+
94
+ filename = noisy_file.replace(self.noisy_audio_path, "")
95
+ filename = filename[1:] if filename.startswith("/") else filename
96
+
97
+ y, sr = load(noisy_file)
98
+
99
+ if sr != target_sr:
100
+ y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))
101
+
102
+ T_orig = y.size(1)
103
+
104
+ # Normalize
105
+ norm_factor = y.abs().max()
106
+ y = y / norm_factor
107
+
108
+ # Prepare DNN input
109
+ Y = torch.unsqueeze(
110
+ self.model._forward_transform(self.model._stft(y.to(self.device))), 0
111
+ )
112
+ Y = pad_spec(Y, mode=pad_mode)
113
+
114
+ # Reverse sampling
115
+ if self.model.sde.__class__.__name__ == "OUVESDE":
116
+ if self.model.sde.sampler_type == "pc":
117
+ sampler = self.model.get_pc_sampler(
118
+ "reverse_diffusion",
119
+ self.corrector,
120
+ Y.to(self.device),
121
+ N=self.N,
122
+ corrector_steps=self.corrector_steps,
123
+ snr=self.snr,
124
+ )
125
+ elif self.model.sde.sampler_type == "ode":
126
+ sampler = self.model.get_ode_sampler(Y.to(self.device), N=self.N)
127
+ else:
128
+ raise ValueError(f"Sampler type {args.sampler_type} not supported")
129
+ elif self.model.sde.__class__.__name__ == "SBVESDE":
130
+ sampler_type = (
131
+ "ode"
132
+ if self.model.sde.sampler_type == "pc"
133
+ else self.model.sde.sampler_type
134
+ )
135
+ sampler = self.model.get_sb_sampler(
136
+ sde=self.model.sde, y=Y.cuda(), sampler_type=sampler_type
137
+ )
138
+ else:
139
+ raise ValueError(
140
+ f"SDE {self.model.sde.__class__.__name__} not supported"
141
+ )
142
+
143
+ sample, _ = sampler()
144
+
145
+ x_hat = self.model.to_audio(sample.squeeze(), T_orig)
146
+
147
+ x_hat = x_hat * norm_factor
148
+
149
+ os.makedirs(
150
+ os.path.dirname(os.path.join(self.enhanced_audio_path, filename)),
151
+ exist_ok=True,
152
+ )
153
+ write(
154
+ os.path.join(self.enhanced_audio_path, filename),
155
+ x_hat.cpu().numpy(),
156
+ target_sr,
157
+ )
158
+
159
+ def _setup_routes(self):
160
+ self.app.get("/status/")(self.get_status)
161
+ self.app.post("/prepare/")(self.prepare)
162
+ self.app.post("/upload-audio/")(self.upload_audio)
163
+ self.app.post("/enhance/")(self.enhance_audio)
164
+ self.app.get("/download-enhanced/")(self.download_enhanced)
165
+ self.app.post("/reset/")(self.reset)
166
+
167
+ def get_status(self):
168
+ try:
169
+ return {"container_running": True}
170
+ except Exception as e:
171
+ logging.error(f"Error getting status: {e}")
172
+ raise fastapi.HTTPException(
173
+ status_code=500, detail="An error occurred while fetching API status."
174
+ )
175
+
176
+ def prepare(self):
177
+ try:
178
+ self._prepare()
179
+ return {"preparations": True}
180
+ except Exception as e:
181
+ logging.error(f"Error during preparations: {e}")
182
+ return fastapi.HTTPException(
183
+ status_code=500, detail="An error occurred while fetching API status."
184
+ )
185
+
186
+ def upload_audio(self, files: List[fastapi.UploadFile] = fastapi.File(...)):
187
+
188
+ uploaded_files = []
189
+
190
+ for file in files:
191
+ try:
192
+ file_path = os.path.join(self.noisy_audio_path, file.filename)
193
+
194
+ with open(file_path, "wb") as f:
195
+ while contents := file.file.read(1024 * 1024):
196
+ f.write(contents)
197
+
198
+ uploaded_files.append(file.filename)
199
+
200
+ except Exception as e:
201
+ logging.error(f"Error uploading files: {e}")
202
+ raise fastapi.HTTPException(
203
+ status_code=500,
204
+ detail="An error occurred while uploading the noisy files.",
205
+ )
206
+ finally:
207
+ file.file.close()
208
+
209
+ print(f"uploaded files: {uploaded_files}")
210
+
211
+ return {"uploaded_files": uploaded_files, "status": True}
212
+
213
+ def enhance_audio(self):
214
+ try:
215
+ # Enhance audio
216
+ self._enhance()
217
+ # Obtain list of file paths for enhanced audio
218
+ wav_files = glob.glob(os.path.join(self.enhanced_audio_path, "*.wav"))
219
+ # Extract just the file names
220
+ enhanced_files = [os.path.basename(file) for file in wav_files]
221
+ return {"status": True}
222
+
223
+ except Exception as e:
224
+ print(f"Exception occured during enhancement: {e}")
225
+ raise fastapi.HTTPException(
226
+ status_code=500,
227
+ detail="An error occurred while enhancing the noisy files.",
228
+ )
229
+
230
+ def download_enhanced(self):
231
+ try:
232
+ zip_buffer = io.BytesIO()
233
+
234
+ with zipfile.ZipFile(zip_buffer, "w") as zip_file:
235
+ for wav_file in glob.glob(
236
+ os.path.join(self.enhanced_audio_path, "*.wav")
237
+ ):
238
+ zip_file.write(wav_file, arcname=os.path.basename(wav_file))
239
+ zip_buffer.seek(0)
240
+
241
+ return fastapi.responses.StreamingResponse(
242
+ iter([zip_buffer.getvalue()]), # Stream the in-memory content
243
+ media_type="application/zip",
244
+ headers={
245
+ "Content-Disposition": "attachment; filename=enhanced_audio_files.zip"
246
+ },
247
+ )
248
+
249
+ except Exception as e:
250
+ logging.error(f"Error during enhanced files download: {e}")
251
+ raise fastapi.HTTPException(
252
+ status_code=500,
253
+ detail=f"An error occurred while creating the download file: {str(e)}",
254
+ )
255
+
256
+ def reset(self):
257
+ """
258
+ Removes all audio files in preparation for another batch of enhancement.
259
+ """
260
+ for directory in [self.noisy_audio_path, self.enhanced_audio_path]:
261
+ if not os.path.isdir(directory):
262
+ continue
263
+
264
+ for filename in os.listdir(directory):
265
+ filepath = os.path.join(directory, filename)
266
+ if os.path.isfile(filepath):
267
+ try:
268
+ os.remove(filepath)
269
+ except Exception as e:
270
+ print(f"Error removing {filepath}: {e}")
271
+ return {
272
+ "status": False,
273
+ "noisy": os.listdir(self.noisy_audio_path),
274
+ "enhanced": os.listdir(self.enhanced_audio_path),
275
+ }
276
+ return {
277
+ "status": True,
278
+ "noisy": os.listdir(self.noisy_audio_path),
279
+ "enhanced": os.listdir(self.enhanced_audio_path),
280
+ }
281
+
282
+ def run(self):
283
+
284
+ uvicorn.run(self.app, host=self.host, port=self.port)