vera6 committed on
Commit
3106fc7
·
verified ·
1 Parent(s): 13db853

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. .gitattributes +64 -63
  2. README.md +1 -1
  3. app/app.py +293 -284
  4. app/derev_elevenlabs_2.ckpt +2 -2
.gitattributes CHANGED
@@ -1,63 +1,64 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- .venv/Scripts/tiny-agents.exe filter=lfs diff=lfs merge=lfs -text
37
- .venv/Scripts/normalizer.exe filter=lfs diff=lfs merge=lfs -text
38
- .venv/Scripts/huggingface-cli.exe filter=lfs diff=lfs merge=lfs -text
39
- .venv/Scripts/tqdm.exe filter=lfs diff=lfs merge=lfs -text
40
- .venv/Scripts/pip.exe filter=lfs diff=lfs merge=lfs -text
41
- .venv/Lib/site-packages/charset_normalizer/md__mypyc.cp312-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text
42
- .venv/Scripts/pip3.12.exe filter=lfs diff=lfs merge=lfs -text
43
- .venv/Scripts/pip3.exe filter=lfs diff=lfs merge=lfs -text
44
- .venv/Scripts/hf.exe filter=lfs diff=lfs merge=lfs -text
45
- .venv/Lib/site-packages/pip/_vendor/rich/__pycache__/console.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
46
- .venv/Lib/site-packages/prompt_toolkit/layout/__pycache__/containers.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
47
- .venv/Lib/site-packages/prompt_toolkit/key_binding/bindings/__pycache__/vi.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
48
- .venv/Lib/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
49
- .venv/Lib/site-packages/pip/_vendor/distlib/w64.exe filter=lfs diff=lfs merge=lfs -text
50
- .venv/Scripts/python.exe filter=lfs diff=lfs merge=lfs -text
51
- .venv/Lib/site-packages/yaml/_yaml.cp312-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text
52
- .venv/Lib/site-packages/pip/_vendor/idna/__pycache__/uts46data.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
53
- .venv/Lib/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
54
- .venv/Lib/site-packages/idna/__pycache__/uts46data.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
55
- .venv/Lib/site-packages/huggingface_hub/inference/__pycache__/_client.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
56
- .venv/Lib/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
57
- .venv/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe filter=lfs diff=lfs merge=lfs -text
58
- .venv/Scripts/pythonw.exe filter=lfs diff=lfs merge=lfs -text
59
- .venv/Lib/site-packages/huggingface_hub/inference/_generated/__pycache__/_async_client.cpython-312.pyc.2109505304720 filter=lfs diff=lfs merge=lfs -text
60
- .venv/Lib/site-packages/pip/_vendor/rich/__pycache__/_emoji_codes.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
61
- .venv/Lib/site-packages/huggingface_hub/__pycache__/hf_api.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
62
- .venv/Lib/site-packages/__pycache__/typing_extensions.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
63
- .venv/Lib/site-packages/hf_transfer/hf_transfer.pyd filter=lfs diff=lfs merge=lfs -text
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ .venv/Scripts/tiny-agents.exe filter=lfs diff=lfs merge=lfs -text
37
+ .venv/Scripts/normalizer.exe filter=lfs diff=lfs merge=lfs -text
38
+ .venv/Scripts/huggingface-cli.exe filter=lfs diff=lfs merge=lfs -text
39
+ .venv/Scripts/tqdm.exe filter=lfs diff=lfs merge=lfs -text
40
+ .venv/Scripts/pip.exe filter=lfs diff=lfs merge=lfs -text
41
+ .venv/Lib/site-packages/charset_normalizer/md__mypyc.cp312-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text
42
+ .venv/Scripts/pip3.12.exe filter=lfs diff=lfs merge=lfs -text
43
+ .venv/Scripts/pip3.exe filter=lfs diff=lfs merge=lfs -text
44
+ .venv/Scripts/hf.exe filter=lfs diff=lfs merge=lfs -text
45
+ .venv/Lib/site-packages/pip/_vendor/rich/__pycache__/console.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
46
+ .venv/Lib/site-packages/prompt_toolkit/layout/__pycache__/containers.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
47
+ .venv/Lib/site-packages/prompt_toolkit/key_binding/bindings/__pycache__/vi.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
48
+ .venv/Lib/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
49
+ .venv/Lib/site-packages/pip/_vendor/distlib/w64.exe filter=lfs diff=lfs merge=lfs -text
50
+ .venv/Scripts/python.exe filter=lfs diff=lfs merge=lfs -text
51
+ .venv/Lib/site-packages/yaml/_yaml.cp312-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text
52
+ .venv/Lib/site-packages/pip/_vendor/idna/__pycache__/uts46data.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
53
+ .venv/Lib/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
54
+ .venv/Lib/site-packages/idna/__pycache__/uts46data.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
55
+ .venv/Lib/site-packages/huggingface_hub/inference/__pycache__/_client.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
56
+ .venv/Lib/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
57
+ .venv/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe filter=lfs diff=lfs merge=lfs -text
58
+ .venv/Scripts/pythonw.exe filter=lfs diff=lfs merge=lfs -text
59
+ .venv/Lib/site-packages/huggingface_hub/inference/_generated/__pycache__/_async_client.cpython-312.pyc.2109505304720 filter=lfs diff=lfs merge=lfs -text
60
+ .venv/Lib/site-packages/pip/_vendor/rich/__pycache__/_emoji_codes.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
61
+ .venv/Lib/site-packages/huggingface_hub/__pycache__/hf_api.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
62
+ .venv/Lib/site-packages/__pycache__/typing_extensions.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
63
+ .venv/Lib/site-packages/hf_transfer/hf_transfer.pyd filter=lfs diff=lfs merge=lfs -text
64
+ app/*.ckpt filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1 +1 @@
1
- DENOISING speech enhancement model
 
1
+ VERA
app/app.py CHANGED
@@ -1,284 +1,293 @@
1
- import fastapi
2
- import shutil
3
- import os
4
- import zipfile
5
- import io
6
- import uvicorn
7
- import threading
8
- import glob
9
- from typing import List
10
- import torch
11
- import gdown
12
- from soundfile import write
13
- from torchaudio import load
14
- from librosa import resample
15
- import logging
16
-
17
- logging.basicConfig(level=logging.DEBUG)
18
-
19
- from sgmse import ScoreModel
20
- from sgmse.util.other import pad_spec
21
-
22
-
23
- class ModelAPI:
24
-
25
- def __init__(self, host, port):
26
-
27
- self.host = host
28
- self.port = port
29
-
30
- self.base_path = os.path.join(os.path.expanduser("~"), ".modelapi")
31
- self.noisy_audio_path = os.path.join(self.base_path, "noisy_audio")
32
- self.enhanced_audio_path = os.path.join(self.base_path, "enhanced_audio")
33
- app_dir = os.path.dirname(os.path.abspath(__file__))
34
- self.ckpt_path = glob.glob(os.path.join(app_dir, "*.ckpt"))[0]
35
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
36
- self.corrector = "ald"
37
- self.corrector_steps = 1
38
- self.snr = 0.5
39
- self.N = 30
40
-
41
- for audio_path in [self.noisy_audio_path, self.enhanced_audio_path]:
42
- if not os.path.exists(audio_path):
43
- os.makedirs(audio_path)
44
-
45
- for filename in os.listdir(audio_path):
46
- file_path = os.path.join(audio_path, filename)
47
-
48
- try:
49
- if os.path.isfile(file_path) or os.path.islink(file_path):
50
- os.unlink(file_path)
51
- elif os.path.isdir(file_path):
52
- shutil.rmtree(file_path)
53
- except Exception as e:
54
- raise e
55
-
56
- self.app = fastapi.FastAPI()
57
- self._setup_routes()
58
-
59
- def _prepare(self):
60
- """Miners should modify this function to fit their fine-tuned models.
61
-
62
- This function will make any preparations necessary to initialize the
63
- speech enhancement model (i.e. downloading checkpoint files, etc.)
64
- """
65
-
66
- self.model = ScoreModel.load_from_checkpoint(self.ckpt_path, self.device)
67
- self.model.t_eps = 0.03
68
- self.model.eval()
69
-
70
- def _enhance(self):
71
- """
72
- Miners should modify this function to fit their fine-tuned models.
73
-
74
- This function will:
75
- 1. Open each noisy .wav file
76
- 2. Enhance the audio with the model
77
- 3. Save the enhanced audio in .wav format to ModelAPI.enhanced_audio_path
78
- """
79
-
80
- if self.model.backbone == "ncsnpp_48k":
81
- target_sr = 48000
82
- pad_mode = "reflection"
83
- elif self.model.backbone == "ncsnpp_v2":
84
- target_sr = 16000
85
- pad_mode = "reflection"
86
- print("using ncsnpp_v2")
87
- else:
88
- target_sr = 16000
89
- pad_mode = "zero_pad"
90
-
91
- noisy_files = sorted(glob.glob(os.path.join(self.noisy_audio_path, "*.wav")))
92
- for noisy_file in noisy_files:
93
-
94
- filename = noisy_file.replace(self.noisy_audio_path, "")
95
- filename = filename[1:] if filename.startswith("/") else filename
96
-
97
- y, sr = load(noisy_file)
98
-
99
- if sr != target_sr:
100
- y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))
101
-
102
- T_orig = y.size(1)
103
-
104
- # Normalize
105
- norm_factor = y.abs().max()
106
- y = y / norm_factor
107
-
108
- # Prepare DNN input
109
- Y = torch.unsqueeze(
110
- self.model._forward_transform(self.model._stft(y.to(self.device))), 0
111
- )
112
- Y = pad_spec(Y, mode=pad_mode)
113
-
114
- # Reverse sampling
115
- if self.model.sde.__class__.__name__ == "OUVESDE":
116
- if self.model.sde.sampler_type == "pc":
117
- sampler = self.model.get_pc_sampler(
118
- "reverse_diffusion",
119
- self.corrector,
120
- Y.to(self.device),
121
- N=self.N,
122
- corrector_steps=self.corrector_steps,
123
- snr=self.snr,
124
- )
125
- elif self.model.sde.sampler_type == "ode":
126
- sampler = self.model.get_ode_sampler(Y.to(self.device), N=self.N)
127
- else:
128
- raise ValueError(f"Sampler type {args.sampler_type} not supported")
129
- elif self.model.sde.__class__.__name__ == "SBVESDE":
130
- sampler_type = (
131
- "ode"
132
- if self.model.sde.sampler_type == "pc"
133
- else self.model.sde.sampler_type
134
- )
135
- sampler = self.model.get_sb_sampler(
136
- sde=self.model.sde, y=Y.cuda(), sampler_type=sampler_type
137
- )
138
- else:
139
- raise ValueError(
140
- f"SDE {self.model.sde.__class__.__name__} not supported"
141
- )
142
-
143
- sample, _ = sampler()
144
-
145
- x_hat = self.model.to_audio(sample.squeeze(), T_orig)
146
-
147
- x_hat = x_hat * norm_factor
148
-
149
- os.makedirs(
150
- os.path.dirname(os.path.join(self.enhanced_audio_path, filename)),
151
- exist_ok=True,
152
- )
153
- write(
154
- os.path.join(self.enhanced_audio_path, filename),
155
- x_hat.cpu().numpy(),
156
- target_sr,
157
- )
158
-
159
- def _setup_routes(self):
160
- self.app.get("/status/")(self.get_status)
161
- self.app.post("/prepare/")(self.prepare)
162
- self.app.post("/upload-audio/")(self.upload_audio)
163
- self.app.post("/enhance/")(self.enhance_audio)
164
- self.app.get("/download-enhanced/")(self.download_enhanced)
165
- self.app.post("/reset/")(self.reset)
166
-
167
- def get_status(self):
168
- try:
169
- return {"container_running": True}
170
- except Exception as e:
171
- logging.error(f"Error getting status: {e}")
172
- raise fastapi.HTTPException(
173
- status_code=500, detail="An error occurred while fetching API status."
174
- )
175
-
176
- def prepare(self):
177
- try:
178
- self._prepare()
179
- return {"preparations": True}
180
- except Exception as e:
181
- logging.error(f"Error during preparations: {e}")
182
- return fastapi.HTTPException(
183
- status_code=500, detail="An error occurred while fetching API status."
184
- )
185
-
186
- def upload_audio(self, files: List[fastapi.UploadFile] = fastapi.File(...)):
187
-
188
- uploaded_files = []
189
-
190
- for file in files:
191
- try:
192
- file_path = os.path.join(self.noisy_audio_path, file.filename)
193
-
194
- with open(file_path, "wb") as f:
195
- while contents := file.file.read(1024 * 1024):
196
- f.write(contents)
197
-
198
- uploaded_files.append(file.filename)
199
-
200
- except Exception as e:
201
- logging.error(f"Error uploading files: {e}")
202
- raise fastapi.HTTPException(
203
- status_code=500,
204
- detail="An error occurred while uploading the noisy files.",
205
- )
206
- finally:
207
- file.file.close()
208
-
209
- print(f"uploaded files: {uploaded_files}")
210
-
211
- return {"uploaded_files": uploaded_files, "status": True}
212
-
213
- def enhance_audio(self):
214
- try:
215
- # Enhance audio
216
- self._enhance()
217
- # Obtain list of file paths for enhanced audio
218
- wav_files = glob.glob(os.path.join(self.enhanced_audio_path, "*.wav"))
219
- # Extract just the file names
220
- enhanced_files = [os.path.basename(file) for file in wav_files]
221
- return {"status": True}
222
-
223
- except Exception as e:
224
- print(f"Exception occured during enhancement: {e}")
225
- raise fastapi.HTTPException(
226
- status_code=500,
227
- detail="An error occurred while enhancing the noisy files.",
228
- )
229
-
230
- def download_enhanced(self):
231
- try:
232
- zip_buffer = io.BytesIO()
233
-
234
- with zipfile.ZipFile(zip_buffer, "w") as zip_file:
235
- for wav_file in glob.glob(
236
- os.path.join(self.enhanced_audio_path, "*.wav")
237
- ):
238
- zip_file.write(wav_file, arcname=os.path.basename(wav_file))
239
- zip_buffer.seek(0)
240
-
241
- return fastapi.responses.StreamingResponse(
242
- iter([zip_buffer.getvalue()]), # Stream the in-memory content
243
- media_type="application/zip",
244
- headers={
245
- "Content-Disposition": "attachment; filename=enhanced_audio_files.zip"
246
- },
247
- )
248
-
249
- except Exception as e:
250
- logging.error(f"Error during enhanced files download: {e}")
251
- raise fastapi.HTTPException(
252
- status_code=500,
253
- detail=f"An error occurred while creating the download file: {str(e)}",
254
- )
255
-
256
- def reset(self):
257
- """
258
- Removes all audio files in preparation for another batch of enhancement.
259
- """
260
- for directory in [self.noisy_audio_path, self.enhanced_audio_path]:
261
- if not os.path.isdir(directory):
262
- continue
263
-
264
- for filename in os.listdir(directory):
265
- filepath = os.path.join(directory, filename)
266
- if os.path.isfile(filepath):
267
- try:
268
- os.remove(filepath)
269
- except Exception as e:
270
- print(f"Error removing {filepath}: {e}")
271
- return {
272
- "status": False,
273
- "noisy": os.listdir(self.noisy_audio_path),
274
- "enhanced": os.listdir(self.enhanced_audio_path),
275
- }
276
- return {
277
- "status": True,
278
- "noisy": os.listdir(self.noisy_audio_path),
279
- "enhanced": os.listdir(self.enhanced_audio_path),
280
- }
281
-
282
- def run(self):
283
-
284
- uvicorn.run(self.app, host=self.host, port=self.port)
 
 
 
 
 
 
 
 
 
 
1
+ import fastapi
2
+ import shutil
3
+ import os
4
+ import zipfile
5
+ import io
6
+ import uvicorn
7
+ import threading
8
+ import glob
9
+ from typing import List
10
+ import torch
11
+ import gdown
12
+ from soundfile import write
13
+ from torchaudio import load
14
+ from librosa import resample
15
+ import logging
16
+
17
+ logging.basicConfig(level=logging.DEBUG)
18
+
19
+ from sgmse import ScoreModel
20
+ from sgmse.util.other import pad_spec
21
+
22
+
23
class ModelAPI:
    """HTTP API around a speech-enhancement checkpoint.

    The instance owns a FastAPI application exposing routes to report
    status, load the model, upload noisy ``.wav`` files, enhance them,
    download the results as a zip archive, and reset the working
    directories. Working files live under ``~/.modelapi``.
    """

    def __init__(self, host, port):
        """Locate the checkpoint, create clean working dirs, register routes.

        Args:
            host: Host address passed to ``uvicorn.run``.
            port: Port passed to ``uvicorn.run``.

        Raises:
            FileNotFoundError: If no ``.ckpt`` file sits next to this module.
            RuntimeError: If more than one ``.ckpt`` file sits next to it.
        """
        self.host = host
        self.port = port

        self.base_path = os.path.join(os.path.expanduser("~"), ".modelapi")
        self.noisy_audio_path = os.path.join(self.base_path, "noisy_audio")
        self.enhanced_audio_path = os.path.join(self.base_path, "enhanced_audio")

        # Exactly one checkpoint is expected beside this file; anything else
        # is ambiguous and rejected up front.
        app_dir = os.path.dirname(os.path.abspath(__file__))
        ckpt_files = glob.glob(os.path.join(app_dir, "*.ckpt"))
        if not ckpt_files:
            raise FileNotFoundError(f"No .ckpt file found in {app_dir}.")
        if len(ckpt_files) > 1:
            raise RuntimeError(
                f"Multiple .ckpt files found in {app_dir}. Please keep only one."
            )
        self.ckpt_path = ckpt_files[0]

        # Populated by _prepare(); None means "/prepare/ has not run yet".
        self.model = None

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Sampler settings forwarded to the model's sampler factories.
        self.corrector = "ald"
        self.corrector_steps = 1
        self.snr = 0.5
        self.N = 30

        # Start from empty working directories; any failure here propagates.
        for audio_path in (self.noisy_audio_path, self.enhanced_audio_path):
            os.makedirs(audio_path, exist_ok=True)
            for filename in os.listdir(audio_path):
                file_path = os.path.join(audio_path, filename)
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)

        self.app = fastapi.FastAPI()
        self._setup_routes()

    def _prepare(self):
        """Miners should modify this function to fit their fine-tuned models.

        This function will make any preparations necessary to initialize the
        speech enhancement model (i.e. downloading checkpoint files, etc.)
        """
        self.model = ScoreModel.load_from_checkpoint(self.ckpt_path, self.device)
        self.model.t_eps = 0.03
        self.model.eval()

    def _build_sampler(self, Y):
        """Return the sampler callable matching the model's SDE class.

        Args:
            Y: Padded model input produced in ``_enhance``.

        Raises:
            ValueError: If the SDE class or its sampler type is unsupported.
        """
        sde = self.model.sde
        sde_name = sde.__class__.__name__
        Y = Y.to(self.device)

        if sde_name == "OUVESDE":
            if sde.sampler_type == "pc":
                return self.model.get_pc_sampler(
                    "reverse_diffusion",
                    self.corrector,
                    Y,
                    N=self.N,
                    corrector_steps=self.corrector_steps,
                    snr=self.snr,
                )
            if sde.sampler_type == "ode":
                return self.model.get_ode_sampler(Y, N=self.N)
            raise ValueError(f"Sampler type {sde.sampler_type} not supported")

        if sde_name == "SBVESDE":
            # A "pc" sampler type is mapped to "ode" for this SDE.
            sampler_type = "ode" if sde.sampler_type == "pc" else sde.sampler_type
            return self.model.get_sb_sampler(
                sde=sde, y=Y, sampler_type=sampler_type
            )

        raise ValueError(f"SDE {sde_name} not supported")

    def _enhance(self):
        """
        Miners should modify this function to fit their fine-tuned models.

        This function will:
        1. Open each noisy .wav file
        2. Enhance the audio with the model
        3. Save the enhanced audio in .wav format to ModelAPI.enhanced_audio_path

        Raises:
            RuntimeError: If the model has not been loaded via ``_prepare``.
            ValueError: If the model's SDE or sampler type is unsupported.
        """
        if getattr(self, "model", None) is None:
            raise RuntimeError("Model is not loaded; call /prepare/ first.")

        # Sample rate and spectrogram padding depend on the backbone.
        if self.model.backbone == "ncsnpp_48k":
            target_sr = 48000
            pad_mode = "reflection"
        elif self.model.backbone == "ncsnpp_v2":
            target_sr = 16000
            pad_mode = "reflection"
            logging.info("using ncsnpp_v2")
        else:
            target_sr = 16000
            pad_mode = "zero_pad"

        noisy_files = sorted(glob.glob(os.path.join(self.noisy_audio_path, "*.wav")))
        for noisy_file in noisy_files:
            # The glob is non-recursive, so the base name is the relative
            # path; basename also copes with platform-specific separators.
            filename = os.path.basename(noisy_file)

            y, sr = load(noisy_file)
            if sr != target_sr:
                y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))

            T_orig = y.size(1)

            # Peak-normalize; an all-zero signal keeps a factor of 1 so the
            # division does not produce NaNs.
            norm_factor = y.abs().max()
            if norm_factor == 0:
                norm_factor = torch.ones_like(norm_factor)
            y = y / norm_factor

            # Prepare DNN input
            Y = torch.unsqueeze(
                self.model._forward_transform(self.model._stft(y.to(self.device))), 0
            )
            Y = pad_spec(Y, mode=pad_mode)

            # Reverse sampling
            sampler = self._build_sampler(Y)
            sample, _ = sampler()

            # Back to a waveform of the original length, original scale.
            x_hat = self.model.to_audio(sample.squeeze(), T_orig)
            x_hat = x_hat * norm_factor

            out_path = os.path.join(self.enhanced_audio_path, filename)
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            write(out_path, x_hat.cpu().numpy(), target_sr)

    def _setup_routes(self):
        """Bind the public handler methods to their HTTP routes."""
        self.app.get("/status/")(self.get_status)
        self.app.post("/prepare/")(self.prepare)
        self.app.post("/upload-audio/")(self.upload_audio)
        self.app.post("/enhance/")(self.enhance_audio)
        self.app.get("/download-enhanced/")(self.download_enhanced)
        self.app.post("/reset/")(self.reset)

    def get_status(self):
        """Report that the API process is up."""
        return {"container_running": True}

    def prepare(self):
        """Load the model; respond with HTTP 500 if loading fails."""
        try:
            self._prepare()
        except Exception as e:
            logging.error(f"Error during preparations: {e}")
            # Raise (not return) so FastAPI actually emits a 500 response.
            raise fastapi.HTTPException(
                status_code=500, detail="An error occurred while preparing the model."
            )
        return {"preparations": True}

    def upload_audio(self, files: List[fastapi.UploadFile] = fastapi.File(...)):
        """Store uploaded files in the noisy-audio directory.

        Only the base name of each client-supplied filename is used, so an
        upload cannot write outside the noisy-audio directory.

        Returns:
            Dict with the stored file names and ``"status": True``.

        Raises:
            fastapi.HTTPException: 400 for an empty filename, 500 on any
                failure while writing a file.
        """
        uploaded_files = []

        for file in files:
            try:
                safe_name = os.path.basename(file.filename or "")
                if not safe_name:
                    raise fastapi.HTTPException(
                        status_code=400,
                        detail="Uploaded file has no valid filename.",
                    )
                file_path = os.path.join(self.noisy_audio_path, safe_name)

                # Copy in 1 MiB chunks to avoid loading whole files in memory.
                with open(file_path, "wb") as f:
                    while contents := file.file.read(1024 * 1024):
                        f.write(contents)

                uploaded_files.append(safe_name)

            except fastapi.HTTPException:
                raise
            except Exception as e:
                logging.error(f"Error uploading files: {e}")
                raise fastapi.HTTPException(
                    status_code=500,
                    detail="An error occurred while uploading the noisy files.",
                )
            finally:
                file.file.close()

        logging.info(f"uploaded files: {uploaded_files}")

        return {"uploaded_files": uploaded_files, "status": True}

    def enhance_audio(self):
        """Run enhancement over all uploaded files; HTTP 500 on failure."""
        try:
            self._enhance()
        except Exception as e:
            logging.error(f"Error during enhancement: {e}")
            raise fastapi.HTTPException(
                status_code=500,
                detail="An error occurred while enhancing the noisy files.",
            )
        return {"status": True}

    def download_enhanced(self):
        """Return every enhanced ``.wav`` file bundled in one zip archive.

        Raises:
            fastapi.HTTPException: 500 if the archive cannot be built.
        """
        try:
            zip_buffer = io.BytesIO()

            with zipfile.ZipFile(zip_buffer, "w") as zip_file:
                for wav_file in glob.glob(
                    os.path.join(self.enhanced_audio_path, "*.wav")
                ):
                    zip_file.write(wav_file, arcname=os.path.basename(wav_file))

            return fastapi.responses.StreamingResponse(
                iter([zip_buffer.getvalue()]),  # Stream the in-memory content
                media_type="application/zip",
                headers={
                    "Content-Disposition": "attachment; filename=enhanced_audio_files.zip"
                },
            )

        except Exception as e:
            logging.error(f"Error during enhanced files download: {e}")
            raise fastapi.HTTPException(
                status_code=500,
                detail=f"An error occurred while creating the download file: {str(e)}",
            )

    def reset(self):
        """
        Removes all audio files in preparation for another batch of enhancement.

        Returns:
            Dict with ``"status"`` (False as soon as one file cannot be
            removed) and the remaining contents of both directories.
        """
        for directory in (self.noisy_audio_path, self.enhanced_audio_path):
            if not os.path.isdir(directory):
                continue

            for filename in os.listdir(directory):
                filepath = os.path.join(directory, filename)
                if not os.path.isfile(filepath):
                    continue
                try:
                    os.remove(filepath)
                except Exception as e:
                    logging.error(f"Error removing {filepath}: {e}")
                    return {
                        "status": False,
                        "noisy": os.listdir(self.noisy_audio_path),
                        "enhanced": os.listdir(self.enhanced_audio_path),
                    }

        return {
            "status": True,
            "noisy": os.listdir(self.noisy_audio_path),
            "enhanced": os.listdir(self.enhanced_audio_path),
        }

    def run(self):
        """Serve the FastAPI app with uvicorn on the configured host/port."""
        uvicorn.run(self.app, host=self.host, port=self.port)
app/derev_elevenlabs_2.ckpt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:be62f390dac0b6a443eeb062ae94e7eba6d93fdbfbc79cd7f7f8027b193b2e23
3
- size 1295837235
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:103ce161f80fffa576079282305ef039ab5ba51f74428f54b2981f5c1a2f84e6
3
+ size 1295832473