vera6 committed on
Commit
a33ea8e
·
verified ·
1 Parent(s): 8d6da92

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. .gitattributes +63 -35
  2. app/app.py +284 -284
  3. app/miner_32.ckpt +2 -2
.gitattributes CHANGED
@@ -1,35 +1,63 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ .venv/Scripts/tiny-agents.exe filter=lfs diff=lfs merge=lfs -text
37
+ .venv/Scripts/normalizer.exe filter=lfs diff=lfs merge=lfs -text
38
+ .venv/Scripts/huggingface-cli.exe filter=lfs diff=lfs merge=lfs -text
39
+ .venv/Scripts/tqdm.exe filter=lfs diff=lfs merge=lfs -text
40
+ .venv/Scripts/pip.exe filter=lfs diff=lfs merge=lfs -text
41
+ .venv/Lib/site-packages/charset_normalizer/md__mypyc.cp312-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text
42
+ .venv/Scripts/pip3.12.exe filter=lfs diff=lfs merge=lfs -text
43
+ .venv/Scripts/pip3.exe filter=lfs diff=lfs merge=lfs -text
44
+ .venv/Scripts/hf.exe filter=lfs diff=lfs merge=lfs -text
45
+ .venv/Lib/site-packages/pip/_vendor/rich/__pycache__/console.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
46
+ .venv/Lib/site-packages/prompt_toolkit/layout/__pycache__/containers.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
47
+ .venv/Lib/site-packages/prompt_toolkit/key_binding/bindings/__pycache__/vi.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
48
+ .venv/Lib/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
49
+ .venv/Lib/site-packages/pip/_vendor/distlib/w64.exe filter=lfs diff=lfs merge=lfs -text
50
+ .venv/Scripts/python.exe filter=lfs diff=lfs merge=lfs -text
51
+ .venv/Lib/site-packages/yaml/_yaml.cp312-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text
52
+ .venv/Lib/site-packages/pip/_vendor/idna/__pycache__/uts46data.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
53
+ .venv/Lib/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
54
+ .venv/Lib/site-packages/idna/__pycache__/uts46data.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
55
+ .venv/Lib/site-packages/huggingface_hub/inference/__pycache__/_client.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
56
+ .venv/Lib/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
57
+ .venv/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe filter=lfs diff=lfs merge=lfs -text
58
+ .venv/Scripts/pythonw.exe filter=lfs diff=lfs merge=lfs -text
59
+ .venv/Lib/site-packages/huggingface_hub/inference/_generated/__pycache__/_async_client.cpython-312.pyc.2109505304720 filter=lfs diff=lfs merge=lfs -text
60
+ .venv/Lib/site-packages/pip/_vendor/rich/__pycache__/_emoji_codes.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
61
+ .venv/Lib/site-packages/huggingface_hub/__pycache__/hf_api.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
62
+ .venv/Lib/site-packages/__pycache__/typing_extensions.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
63
+ .venv/Lib/site-packages/hf_transfer/hf_transfer.pyd filter=lfs diff=lfs merge=lfs -text
app/app.py CHANGED
@@ -1,284 +1,284 @@
1
- import fastapi
2
- import shutil
3
- import os
4
- import zipfile
5
- import io
6
- import uvicorn
7
- import threading
8
- import glob
9
- from typing import List
10
- import torch
11
- import gdown
12
- from soundfile import write
13
- from torchaudio import load
14
- from librosa import resample
15
- import logging
16
-
17
- logging.basicConfig(level=logging.DEBUG)
18
-
19
- from sgmse import ScoreModel
20
- from sgmse.util.other import pad_spec
21
-
22
-
23
- class ModelAPI:
24
-
25
- def __init__(self, host, port):
26
-
27
- self.host = host
28
- self.port = port
29
-
30
- self.base_path = os.path.join(os.path.expanduser("~"), ".modelapi")
31
- self.noisy_audio_path = os.path.join(self.base_path, "noisy_audio")
32
- self.enhanced_audio_path = os.path.join(self.base_path, "enhanced_audio")
33
- app_dir = os.path.dirname(os.path.abspath(__file__))
34
- self.ckpt_path = glob.glob(os.path.join(app_dir, "*.ckpt"))[0]
35
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
36
- self.corrector = "ald"
37
- self.corrector_steps = 1
38
- self.snr = 0.5
39
- self.N = 30
40
-
41
- for audio_path in [self.noisy_audio_path, self.enhanced_audio_path]:
42
- if not os.path.exists(audio_path):
43
- os.makedirs(audio_path)
44
-
45
- for filename in os.listdir(audio_path):
46
- file_path = os.path.join(audio_path, filename)
47
-
48
- try:
49
- if os.path.isfile(file_path) or os.path.islink(file_path):
50
- os.unlink(file_path)
51
- elif os.path.isdir(file_path):
52
- shutil.rmtree(file_path)
53
- except Exception as e:
54
- raise e
55
-
56
- self.app = fastapi.FastAPI()
57
- self._setup_routes()
58
-
59
- def _prepare(self):
60
- """Miners should modify this function to fit their fine-tuned models.
61
-
62
- This function will make any preparations necessary to initialize the
63
- speech enhancement model (i.e. downloading checkpoint files, etc.)
64
- """
65
-
66
- self.model = ScoreModel.load_from_checkpoint(self.ckpt_path, self.device)
67
- self.model.t_eps = 0.03
68
- self.model.eval()
69
-
70
- def _enhance(self):
71
- """
72
- Miners should modify this function to fit their fine-tuned models.
73
-
74
- This function will:
75
- 1. Open each noisy .wav file
76
- 2. Enhance the audio with the model
77
- 3. Save the enhanced audio in .wav format to ModelAPI.enhanced_audio_path
78
- """
79
-
80
- if self.model.backbone == "ncsnpp_48k":
81
- target_sr = 48000
82
- pad_mode = "reflection"
83
- elif self.model.backbone == "ncsnpp_v2":
84
- target_sr = 16000
85
- pad_mode = "reflection"
86
- print("using ncsnpp_v2")
87
- else:
88
- target_sr = 16000
89
- pad_mode = "zero_pad"
90
-
91
- noisy_files = sorted(glob.glob(os.path.join(self.noisy_audio_path, "*.wav")))
92
- for noisy_file in noisy_files:
93
-
94
- filename = noisy_file.replace(self.noisy_audio_path, "")
95
- filename = filename[1:] if filename.startswith("/") else filename
96
-
97
- y, sr = load(noisy_file)
98
-
99
- if sr != target_sr:
100
- y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))
101
-
102
- T_orig = y.size(1)
103
-
104
- # Normalize
105
- norm_factor = y.abs().max()
106
- y = y / norm_factor
107
-
108
- # Prepare DNN input
109
- Y = torch.unsqueeze(
110
- self.model._forward_transform(self.model._stft(y.to(self.device))), 0
111
- )
112
- Y = pad_spec(Y, mode=pad_mode)
113
-
114
- # Reverse sampling
115
- if self.model.sde.__class__.__name__ == "OUVESDE":
116
- if self.model.sde.sampler_type == "pc":
117
- sampler = self.model.get_pc_sampler(
118
- "reverse_diffusion",
119
- self.corrector,
120
- Y.to(self.device),
121
- N=self.N,
122
- corrector_steps=self.corrector_steps,
123
- snr=self.snr,
124
- )
125
- elif self.model.sde.sampler_type == "ode":
126
- sampler = self.model.get_ode_sampler(Y.to(self.device), N=self.N)
127
- else:
128
- raise ValueError(f"Sampler type {args.sampler_type} not supported")
129
- elif self.model.sde.__class__.__name__ == "SBVESDE":
130
- sampler_type = (
131
- "ode"
132
- if self.model.sde.sampler_type == "pc"
133
- else self.model.sde.sampler_type
134
- )
135
- sampler = self.model.get_sb_sampler(
136
- sde=self.model.sde, y=Y.cuda(), sampler_type=sampler_type
137
- )
138
- else:
139
- raise ValueError(
140
- f"SDE {self.model.sde.__class__.__name__} not supported"
141
- )
142
-
143
- sample, _ = sampler()
144
-
145
- x_hat = self.model.to_audio(sample.squeeze(), T_orig)
146
-
147
- x_hat = x_hat * norm_factor
148
-
149
- os.makedirs(
150
- os.path.dirname(os.path.join(self.enhanced_audio_path, filename)),
151
- exist_ok=True,
152
- )
153
- write(
154
- os.path.join(self.enhanced_audio_path, filename),
155
- x_hat.cpu().numpy(),
156
- target_sr,
157
- )
158
-
159
- def _setup_routes(self):
160
- self.app.get("/status/")(self.get_status)
161
- self.app.post("/prepare/")(self.prepare)
162
- self.app.post("/upload-audio/")(self.upload_audio)
163
- self.app.post("/enhance/")(self.enhance_audio)
164
- self.app.get("/download-enhanced/")(self.download_enhanced)
165
- self.app.post("/reset/")(self.reset)
166
-
167
- def get_status(self):
168
- try:
169
- return {"container_running": True}
170
- except Exception as e:
171
- logging.error(f"Error getting status: {e}")
172
- raise fastapi.HTTPException(
173
- status_code=500, detail="An error occurred while fetching API status."
174
- )
175
-
176
- def prepare(self):
177
- try:
178
- self._prepare()
179
- return {"preparations": True}
180
- except Exception as e:
181
- logging.error(f"Error during preparations: {e}")
182
- return fastapi.HTTPException(
183
- status_code=500, detail="An error occurred while fetching API status."
184
- )
185
-
186
- def upload_audio(self, files: List[fastapi.UploadFile] = fastapi.File(...)):
187
-
188
- uploaded_files = []
189
-
190
- for file in files:
191
- try:
192
- file_path = os.path.join(self.noisy_audio_path, file.filename)
193
-
194
- with open(file_path, "wb") as f:
195
- while contents := file.file.read(1024 * 1024):
196
- f.write(contents)
197
-
198
- uploaded_files.append(file.filename)
199
-
200
- except Exception as e:
201
- logging.error(f"Error uploading files: {e}")
202
- raise fastapi.HTTPException(
203
- status_code=500,
204
- detail="An error occurred while uploading the noisy files.",
205
- )
206
- finally:
207
- file.file.close()
208
-
209
- print(f"uploaded files: {uploaded_files}")
210
-
211
- return {"uploaded_files": uploaded_files, "status": True}
212
-
213
- def enhance_audio(self):
214
- try:
215
- # Enhance audio
216
- self._enhance()
217
- # Obtain list of file paths for enhanced audio
218
- wav_files = glob.glob(os.path.join(self.enhanced_audio_path, "*.wav"))
219
- # Extract just the file names
220
- enhanced_files = [os.path.basename(file) for file in wav_files]
221
- return {"status": True}
222
-
223
- except Exception as e:
224
- print(f"Exception occured during enhancement: {e}")
225
- raise fastapi.HTTPException(
226
- status_code=500,
227
- detail="An error occurred while enhancing the noisy files.",
228
- )
229
-
230
- def download_enhanced(self):
231
- try:
232
- zip_buffer = io.BytesIO()
233
-
234
- with zipfile.ZipFile(zip_buffer, "w") as zip_file:
235
- for wav_file in glob.glob(
236
- os.path.join(self.enhanced_audio_path, "*.wav")
237
- ):
238
- zip_file.write(wav_file, arcname=os.path.basename(wav_file))
239
- zip_buffer.seek(0)
240
-
241
- return fastapi.responses.StreamingResponse(
242
- iter([zip_buffer.getvalue()]), # Stream the in-memory content
243
- media_type="application/zip",
244
- headers={
245
- "Content-Disposition": "attachment; filename=enhanced_audio_files.zip"
246
- },
247
- )
248
-
249
- except Exception as e:
250
- logging.error(f"Error during enhanced files download: {e}")
251
- raise fastapi.HTTPException(
252
- status_code=500,
253
- detail=f"An error occurred while creating the download file: {str(e)}",
254
- )
255
-
256
- def reset(self):
257
- """
258
- Removes all audio files in preparation for another batch of enhancement.
259
- """
260
- for directory in [self.noisy_audio_path, self.enhanced_audio_path]:
261
- if not os.path.isdir(directory):
262
- continue
263
-
264
- for filename in os.listdir(directory):
265
- filepath = os.path.join(directory, filename)
266
- if os.path.isfile(filepath):
267
- try:
268
- os.remove(filepath)
269
- except Exception as e:
270
- print(f"Error removing {filepath}: {e}")
271
- return {
272
- "status": False,
273
- "noisy": os.listdir(self.noisy_audio_path),
274
- "enhanced": os.listdir(self.enhanced_audio_path),
275
- }
276
- return {
277
- "status": True,
278
- "noisy": os.listdir(self.noisy_audio_path),
279
- "enhanced": os.listdir(self.enhanced_audio_path),
280
- }
281
-
282
- def run(self):
283
-
284
- uvicorn.run(self.app, host=self.host, port=self.port)
 
1
+ import fastapi
2
+ import shutil
3
+ import os
4
+ import zipfile
5
+ import io
6
+ import uvicorn
7
+ import threading
8
+ import glob
9
+ from typing import List
10
+ import torch
11
+ import gdown
12
+ from soundfile import write
13
+ from torchaudio import load
14
+ from librosa import resample
15
+ import logging
16
+
17
+ logging.basicConfig(level=logging.DEBUG)
18
+
19
+ from sgmse import ScoreModel
20
+ from sgmse.util.other import pad_spec
21
+
22
+
23
class ModelAPI:
    """FastAPI wrapper exposing a speech-enhancement model over HTTP.

    The API works on two scratch directories under ``~/.modelapi``:
    ``noisy_audio`` (uploaded inputs) and ``enhanced_audio`` (model outputs).
    Routes registered by ``_setup_routes``:

    * ``GET  /status/``            -- liveness probe
    * ``POST /prepare/``           -- load the checkpoint
    * ``POST /upload-audio/``      -- store noisy ``.wav`` files
    * ``POST /enhance/``           -- enhance every uploaded file
    * ``GET  /download-enhanced/`` -- download the results as a zip
    * ``POST /reset/``             -- delete uploaded and enhanced files
    """

    def __init__(self, host, port):
        """Configure paths and sampler settings, clear stale audio, build the app.

        Args:
            host: Interface passed to ``uvicorn.run``.
            port: Port passed to ``uvicorn.run``.

        Raises:
            FileNotFoundError: If no ``*.ckpt`` file sits next to this module.
        """
        self.host = host
        self.port = port

        self.base_path = os.path.join(os.path.expanduser("~"), ".modelapi")
        self.noisy_audio_path = os.path.join(self.base_path, "noisy_audio")
        self.enhanced_audio_path = os.path.join(self.base_path, "enhanced_audio")

        # The checkpoint is discovered next to this file. Sorting makes the
        # choice deterministic if more than one checkpoint is present.
        app_dir = os.path.dirname(os.path.abspath(__file__))
        checkpoints = sorted(glob.glob(os.path.join(app_dir, "*.ckpt")))
        if not checkpoints:
            raise FileNotFoundError(f"No .ckpt checkpoint found in {app_dir}")
        self.ckpt_path = checkpoints[0]

        # Populated by _prepare(); None means "not loaded yet".
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Sampler settings forwarded to the model in _build_sampler().
        self.corrector = "ald"
        self.corrector_steps = 1
        self.snr = 0.5
        self.N = 30

        # Start from empty scratch directories: create them if needed and
        # remove any files, links or sub-directories left from a previous run.
        for audio_path in (self.noisy_audio_path, self.enhanced_audio_path):
            os.makedirs(audio_path, exist_ok=True)

            for filename in os.listdir(audio_path):
                file_path = os.path.join(audio_path, filename)
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)

        self.app = fastapi.FastAPI()
        self._setup_routes()

    def _prepare(self):
        """Miners should modify this function to fit their fine-tuned models.

        This function will make any preparations necessary to initialize the
        speech enhancement model (i.e. downloading checkpoint files, etc.)
        """
        self.model = ScoreModel.load_from_checkpoint(self.ckpt_path, self.device)
        self.model.t_eps = 0.03
        self.model.eval()

    def _build_sampler(self, Y):
        """Return the sampler callable matching the loaded model's SDE.

        Args:
            Y: Padded model input produced in ``_enhance``.

        Raises:
            ValueError: If the SDE class or its sampler type is not handled.
        """
        sde = self.model.sde
        sde_name = sde.__class__.__name__
        Y = Y.to(self.device)

        if sde_name == "OUVESDE":
            if sde.sampler_type == "pc":
                return self.model.get_pc_sampler(
                    "reverse_diffusion",
                    self.corrector,
                    Y,
                    N=self.N,
                    corrector_steps=self.corrector_steps,
                    snr=self.snr,
                )
            if sde.sampler_type == "ode":
                return self.model.get_ode_sampler(Y, N=self.N)
            # Was `args.sampler_type`, an undefined name that raised NameError.
            raise ValueError(f"Sampler type {sde.sampler_type} not supported")

        if sde_name == "SBVESDE":
            # A configured "pc" sampler type is mapped to "ode" for this SDE.
            sampler_type = "ode" if sde.sampler_type == "pc" else sde.sampler_type
            # Use the configured device instead of an unconditional .cuda(),
            # so CPU-only hosts work too.
            return self.model.get_sb_sampler(sde=sde, y=Y, sampler_type=sampler_type)

        raise ValueError(f"SDE {sde_name} not supported")

    def _enhance(self):
        """
        Miners should modify this function to fit their fine-tuned models.

        This function will:
        1. Open each noisy .wav file
        2. Enhance the audio with the model
        3. Save the enhanced audio in .wav format to ModelAPI.enhanced_audio_path

        Raises:
            RuntimeError: If the model has not been loaded via ``_prepare``.
            ValueError: If the model's SDE or sampler type is not supported.
        """
        if getattr(self, "model", None) is None:
            raise RuntimeError("Model is not loaded; call /prepare/ first.")

        # Target sample rate and spectrogram padding depend on the backbone.
        if self.model.backbone == "ncsnpp_48k":
            target_sr = 48000
            pad_mode = "reflection"
        elif self.model.backbone == "ncsnpp_v2":
            target_sr = 16000
            pad_mode = "reflection"
            print("using ncsnpp_v2")
        else:
            target_sr = 16000
            pad_mode = "zero_pad"

        os.makedirs(self.enhanced_audio_path, exist_ok=True)

        noisy_files = sorted(glob.glob(os.path.join(self.noisy_audio_path, "*.wav")))
        for noisy_file in noisy_files:
            # The glob only matches direct children, so the base name is the
            # relative name. This also works with Windows path separators,
            # unlike stripping a leading "/".
            filename = os.path.basename(noisy_file)

            y, sr = load(noisy_file)

            if sr != target_sr:
                y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))

            # Length along dim 1, used to trim the output back to input length.
            T_orig = y.size(1)

            # Normalize to unit peak. Silent (all-zero) input would divide by
            # zero and fill y with NaNs, so fall back to a factor of 1.
            norm_factor = y.abs().max()
            if norm_factor == 0:
                norm_factor = torch.ones_like(norm_factor)
            y = y / norm_factor

            # Prepare DNN input
            Y = torch.unsqueeze(
                self.model._forward_transform(self.model._stft(y.to(self.device))), 0
            )
            Y = pad_spec(Y, mode=pad_mode)

            # Reverse sampling
            sampler = self._build_sampler(Y)
            sample, _ = sampler()

            x_hat = self.model.to_audio(sample.squeeze(), T_orig)

            # Undo the peak normalization applied above.
            x_hat = x_hat * norm_factor

            write(
                os.path.join(self.enhanced_audio_path, filename),
                x_hat.cpu().numpy(),
                target_sr,
            )

    def _setup_routes(self):
        """Register every endpoint handler on the FastAPI application."""
        self.app.get("/status/")(self.get_status)
        self.app.post("/prepare/")(self.prepare)
        self.app.post("/upload-audio/")(self.upload_audio)
        self.app.post("/enhance/")(self.enhance_audio)
        self.app.get("/download-enhanced/")(self.download_enhanced)
        self.app.post("/reset/")(self.reset)

    def get_status(self):
        """Liveness probe: always reports that the container is running."""
        return {"container_running": True}

    def prepare(self):
        """Load the model and report success.

        Raises:
            fastapi.HTTPException: 500 if the preparations fail.
        """
        try:
            self._prepare()
            return {"preparations": True}
        except Exception as e:
            logging.error(f"Error during preparations: {e}")
            # Raise (not return) so the client actually receives a 500 response.
            raise fastapi.HTTPException(
                status_code=500,
                detail="An error occurred during model preparations.",
            )

    def upload_audio(self, files: List[fastapi.UploadFile] = fastapi.File(...)):
        """Store the uploaded files in the noisy-audio directory.

        Only the base name of each client-supplied filename is used, so an
        upload cannot be written outside ``noisy_audio_path``.

        Returns:
            Dict with the stored file names and ``status: True``.

        Raises:
            fastapi.HTTPException: 500 if any file cannot be stored.
        """
        uploaded_files = []

        for file in files:
            try:
                # Client-controlled value: drop any directory components.
                safe_name = os.path.basename(file.filename or "")
                if not safe_name:
                    raise ValueError("uploaded file has no usable filename")

                file_path = os.path.join(self.noisy_audio_path, safe_name)

                # Copy in 1 MiB chunks to avoid holding whole files in memory.
                with open(file_path, "wb") as f:
                    while contents := file.file.read(1024 * 1024):
                        f.write(contents)

                uploaded_files.append(safe_name)

            except Exception as e:
                logging.error(f"Error uploading files: {e}")
                raise fastapi.HTTPException(
                    status_code=500,
                    detail="An error occurred while uploading the noisy files.",
                )
            finally:
                file.file.close()

        logging.info(f"uploaded files: {uploaded_files}")

        return {"uploaded_files": uploaded_files, "status": True}

    def enhance_audio(self):
        """Run enhancement over every uploaded file.

        Raises:
            fastapi.HTTPException: 500 if enhancement fails for any reason.
        """
        try:
            self._enhance()
            return {"status": True}

        except Exception as e:
            logging.error(f"Exception occured during enhancement: {e}")
            raise fastapi.HTTPException(
                status_code=500,
                detail="An error occurred while enhancing the noisy files.",
            )

    def download_enhanced(self):
        """Return all enhanced ``.wav`` files as one in-memory zip archive.

        Raises:
            fastapi.HTTPException: 500 if the archive cannot be created.
        """
        # Explicit import: a bare `import fastapi` is not guaranteed to expose
        # the `responses` submodule as an attribute.
        from fastapi.responses import StreamingResponse

        try:
            zip_buffer = io.BytesIO()

            with zipfile.ZipFile(zip_buffer, "w") as zip_file:
                wav_files = sorted(
                    glob.glob(os.path.join(self.enhanced_audio_path, "*.wav"))
                )
                for wav_file in wav_files:
                    zip_file.write(wav_file, arcname=os.path.basename(wav_file))

            return StreamingResponse(
                iter([zip_buffer.getvalue()]),  # Stream the in-memory content
                media_type="application/zip",
                headers={
                    "Content-Disposition": "attachment; filename=enhanced_audio_files.zip"
                },
            )

        except Exception as e:
            logging.error(f"Error during enhanced files download: {e}")
            raise fastapi.HTTPException(
                status_code=500,
                detail=f"An error occurred while creating the download file: {str(e)}",
            )

    def reset(self):
        """
        Removes all audio files in preparation for another batch of enhancement.

        Every file is attempted even if an earlier removal fails. ``status``
        is False if at least one file could not be removed. Missing
        directories are recreated so the returned listings never fail.

        Returns:
            Dict with ``status`` and the remaining entries of both directories.
        """
        status = True

        for directory in (self.noisy_audio_path, self.enhanced_audio_path):
            os.makedirs(directory, exist_ok=True)

            for filename in os.listdir(directory):
                filepath = os.path.join(directory, filename)
                if not os.path.isfile(filepath):
                    continue
                try:
                    os.remove(filepath)
                except OSError as e:
                    logging.error(f"Error removing {filepath}: {e}")
                    status = False

        return {
            "status": status,
            "noisy": os.listdir(self.noisy_audio_path),
            "enhanced": os.listdir(self.enhanced_audio_path),
        }

    def run(self):
        """Serve the application with uvicorn on the configured host and port."""
        uvicorn.run(self.app, host=self.host, port=self.port)
app/miner_32.ckpt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e7cdeefcd8eacb4767018f5094979e57973c142871357fe208c8cd362f3218a1
3
- size 1312970987
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67492528f6ae5c2444093b8a0d4f58fd4dba4cf186aef30f0d0bc8bb4f9b3d24
3
+ size 1313035988