Iliass Lasri commited on
Commit
27d7586
·
1 Parent(s): 54dc2f8

added all files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +7 -6
  2. app.py +285 -0
  3. fspen/.gitignore +158 -0
  4. fspen/.pre-commit-config.yaml +134 -0
  5. fspen/.project-root +2 -0
  6. fspen/Makefile +30 -0
  7. fspen/README.md +94 -0
  8. fspen/configs/__init__.py +1 -0
  9. fspen/configs/callbacks/default.yaml +23 -0
  10. fspen/configs/callbacks/early_stopping.yaml +15 -0
  11. fspen/configs/callbacks/model_checkpoint.yaml +17 -0
  12. fspen/configs/callbacks/model_summary.yaml +5 -0
  13. fspen/configs/callbacks/none.yaml +0 -0
  14. fspen/configs/callbacks/rich_progress_bar.yaml +4 -0
  15. fspen/configs/data/speech_enhancement.yaml +13 -0
  16. fspen/configs/debug/default.yaml +35 -0
  17. fspen/configs/debug/fdr.yaml +9 -0
  18. fspen/configs/debug/limit.yaml +12 -0
  19. fspen/configs/debug/overfit.yaml +13 -0
  20. fspen/configs/debug/profiler.yaml +12 -0
  21. fspen/configs/eval.yaml +19 -0
  22. fspen/configs/experiment/example.yaml +41 -0
  23. fspen/configs/extras/default.yaml +8 -0
  24. fspen/configs/hparams_search/mnist_optuna.yaml +52 -0
  25. fspen/configs/hydra/default.yaml +19 -0
  26. fspen/configs/local/.gitkeep +0 -0
  27. fspen/configs/logger/aim.yaml +28 -0
  28. fspen/configs/logger/comet.yaml +12 -0
  29. fspen/configs/logger/csv.yaml +7 -0
  30. fspen/configs/logger/many_loggers.yaml +9 -0
  31. fspen/configs/logger/mlflow.yaml +12 -0
  32. fspen/configs/logger/neptune.yaml +9 -0
  33. fspen/configs/logger/tensorboard.yaml +10 -0
  34. fspen/configs/logger/wandb.yaml +16 -0
  35. fspen/configs/model/fspen.yaml +24 -0
  36. fspen/configs/paths/default.yaml +19 -0
  37. fspen/configs/paths/eval.yaml +19 -0
  38. fspen/configs/train.yaml +49 -0
  39. fspen/configs/trainer/cpu.yaml +5 -0
  40. fspen/configs/trainer/ddp.yaml +9 -0
  41. fspen/configs/trainer/ddp_sim.yaml +7 -0
  42. fspen/configs/trainer/default.yaml +19 -0
  43. fspen/configs/trainer/gpu.yaml +5 -0
  44. fspen/configs/trainer/mps.yaml +5 -0
  45. fspen/environment.yaml +125 -0
  46. fspen/notebooks/.gitkeep +0 -0
  47. fspen/pyproject.toml +25 -0
  48. fspen/requirements.txt +24 -0
  49. fspen/scripts/schedule.sh +7 -0
  50. fspen/setup.py +21 -0
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
- title: Fspen
3
- emoji: 🔥
4
- colorFrom: yellow
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 6.2.0
8
  app_file: app.py
 
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: DeepFilterNet2
3
+ emoji: 💩
4
+ colorFrom: gray
5
+ colorTo: red
6
  sdk: gradio
 
7
  app_file: app.py
8
+ sdk_version: 3.17.1
9
  pinned: false
10
+ license: apache-2.0
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import math
3
+ import os
4
+ import sys
5
+ import tempfile
6
+ import time
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import gradio as gr
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import torch
13
+ import torchaudio
14
+ import torchaudio.transforms as T
15
+ from loguru import logger
16
+ from PIL import Image
17
+
18
# Make the bundled `fspen` package importable before pulling in its inference helper.
sys.path.append("fspen")
from fspen.src.test import enhance_audio

# Path to the model checkpoint consumed by `enhance_audio`.
# NOTE(review): empty string here — presumably filled in at deploy time; confirm.
CHECKPOINT_PATH = ""
# All audio in this app is resampled to 16 kHz.
TARGET_SR = 16000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- PLOTTING SETUP ---
# Module-level figures/axes are created once and reused across requests so that
# each call to `demo_fn` can clear and redraw in place instead of re-allocating.
fig_noisy: plt.Figure
fig_enh: plt.Figure
ax_noisy: plt.Axes
ax_enh: plt.Axes
fig_noisy, ax_noisy = plt.subplots(figsize=(15.2, 4))
fig_noisy.set_tight_layout(True)
fig_enh, ax_enh = plt.subplots(figsize=(15.2, 4))
fig_enh.set_tight_layout(True)

# Mapping from UI dropdown label to a background-noise sample file
# (None means "do not add noise").
NOISES = {
    "None": None,
    "Kitchen": "samples/dkitchen.wav",
    "Living Room": "samples/dliving.wav",
    "River": "samples/nriver.wav",
    "Cafe": "samples/scafe.wav",
}
42
+
43
+ # --- HELPER FUNCTIONS ---
44
+
45
+
46
def load_audio_torch(path, target_sr=TARGET_SR):
    """Load an audio file with torchaudio, resampled to `target_sr`.

    Drop-in replacement for `df.load_audio`. Returns `(signal, target_sr)`,
    or `(None, None)` when no path is given.
    """
    if path is None:
        return None, None

    waveform, orig_sr = torchaudio.load(path)
    if orig_sr != target_sr:
        waveform = T.Resample(orig_sr, target_sr)(waveform)
    return waveform, target_sr
56
+
57
+
58
def save_audio_torch(path, tensor, sr):
    """Write `tensor` to `path` as audio at sample rate `sr` (replaces `df.save_audio`)."""
    # torchaudio.save expects a CPU tensor shaped [channels, time].
    out = tensor.detach().cpu()
    if out.dim() == 1:
        out = out.unsqueeze(0)
    torchaudio.save(path, out, sr)
66
+
67
+
68
def mix_at_snr(clean, noise, snr, eps=1e-10):
    """Mix `clean` and `noise` at the requested SNR (in dB).

    Both signals are downmixed to mono with shape (1, T); the noise is tiled
    and randomly cropped to the clean length, then scaled so that
    10*log10(E_speech / E_noise) == snr. Returns (clean, noise, mixture),
    all three jointly peak-normalized when the mixture would clip.
    """
    # Bring both signals to shape (1, T), mono.
    clean = clean.unsqueeze(0) if clean.dim() == 1 else clean
    noise = noise.unsqueeze(0) if noise.dim() == 1 else noise
    clean = clean.mean(0, keepdim=True)
    noise = noise.mean(0, keepdim=True)

    # Tile the noise so it is at least as long as the speech...
    n_clean = clean.shape[1]
    if noise.shape[1] < n_clean:
        reps = int(math.ceil(n_clean / noise.shape[1]))
        noise = noise.repeat((1, reps))
    # ...then take a random window of matching length.
    slack = int(noise.shape[1] - n_clean)
    offset = torch.randint(0, slack, ()).item() if slack > 0 else 0
    noise = noise[:, offset : offset + n_clean]

    # Scale the noise to hit the target SNR (eps guards against silence).
    e_speech = torch.mean(clean.pow(2)) + eps
    e_noise = torch.mean(noise.pow(2))
    gain = torch.sqrt((e_noise / e_speech) * 10 ** (snr / 10) + eps)
    noise = noise / gain
    mixture = clean + noise

    # Jointly rescale if the mixture clips, keeping relative levels intact.
    peak = mixture.abs().max()
    if peak > 1:
        clean, noise, mixture = clean / peak, noise / peak, mixture / peak
    return clean, noise, mixture
95
+
96
+
97
def specshow(
    spec,
    ax=None,
    title=None,
    sr=48000,
    n_fft=None,
    hop=None,
    t=None,
    f=None,
    vmin=-100,
    vmax=0,
    cmap="inferno",
):
    """Plot a spectrogram of shape [F, T] (values in dB) onto a matplotlib Axes.

    :param spec: spectrogram array/tensor, frequency bins first.
    :param ax: Axes to draw into; defaults to the current Axes.
    :param title: optional plot title.
    :param sr: sample rate in Hz, used to scale the axes.
    :param n_fft: FFT size; inferred from the bin count when omitted.
    :param hop: hop size; defaults to n_fft // 4.
    :param t: precomputed time axis (seconds); computed when omitted.
    :param f: precomputed frequency axis (kHz); computed when omitted.
    :param vmin: lower color-scale limit in dB.
    :param vmax: upper color-scale limit in dB.
    :param cmap: matplotlib colormap name.
    :return: the QuadMesh returned by `pcolormesh`.
    """
    spec_np = spec.cpu().numpy() if isinstance(spec, torch.Tensor) else spec
    if ax is None:
        # BUG FIX: the previous fallback was `ax = plt` (the pyplot module),
        # which exposes `title()`/`xlabel()` but not the Axes methods
        # `set_title`/`set_xlabel`/`set_ylabel` called below, so any ax=None
        # call crashed. Use the current Axes instead.
        ax = plt.gca()

    if n_fft is None:
        n_fft = (spec.shape[0] - 1) * 2
    hop = hop or n_fft // 4

    if t is None:
        t = np.arange(0, spec_np.shape[-1]) * hop / sr
    if f is None:
        f = np.arange(0, spec_np.shape[0]) * sr // 2 / (n_fft // 2) / 1000

    im = ax.pcolormesh(
        t, f, spec_np, rasterized=True, shading="auto", vmin=vmin, vmax=vmax, cmap=cmap
    )
    if title:
        ax.set_title(title)
    ax.set_xlabel("Time [s]")
    ax.set_ylabel("Frequency [kHz]")
    return im
132
+
133
+
134
def spec_im(
    audio: torch.Tensor, sr=TARGET_SR, figsize=(15, 5), figure=None, ax=None
) -> Image.Image:
    """Render a log-magnitude spectrogram of `audio` and return it as a PIL image.

    :param audio: waveform tensor; multi-channel input is downmixed to mono.
    :param sr: sample rate used for the time/frequency axes.
    :param figsize: figure size, used only when `figure` is not supplied.
    :param figure: optional matplotlib figure to reuse across calls.
    :param ax: optional Axes to draw into (forwarded to `specshow`).
    :return: the rendered figure as a PIL image.
        (BUG FIX: the annotation previously named the `PIL.Image` module
        instead of the `Image.Image` class.)
    """
    audio = torch.as_tensor(audio)
    if audio.dim() > 1:
        audio = audio.mean(dim=0)  # Mix to mono for spec

    n_fft = 1024
    hop = 512
    w = torch.hann_window(n_fft, device=audio.device)
    spec = torch.stft(audio, n_fft, hop, window=w, return_complex=False)
    # Normalize by the window energy before converting magnitude to dB.
    spec = spec.div_(w.pow(2).sum())
    spec = torch.view_as_complex(spec).abs().clamp_min(1e-12).log10().mul(10)

    if figure is None:
        figure = plt.figure(figsize=figsize)
        figure.set_tight_layout(True)

    if spec.dim() > 2:
        spec = spec.squeeze(0)
    specshow(spec, ax=ax, sr=sr, n_fft=n_fft, hop=hop)

    figure.canvas.draw()
    # NOTE(review): `tostring_rgb` was deprecated and later removed in newer
    # matplotlib releases; if the deployed matplotlib is recent, switch to
    # `figure.canvas.buffer_rgba()` with mode "RGBA" — confirm pinned version.
    return Image.frombytes("RGB", figure.canvas.get_width_height(), figure.canvas.tostring_rgb())
156
+
157
+
158
def cleanup_tmp(filter_list: Optional[List[str]] = None, hours_keep=2):
    """Placeholder for cleaning stale wav files under /tmp.

    The deletion logic is intentionally not implemented yet; the function only
    enumerates candidate files.

    :param filter_list: optional filenames to restrict the cleanup to.
        BUG FIX: the default was a mutable list (`= []`), which is shared
        across calls in Python; replaced with the None-sentinel idiom.
    :param hours_keep: age threshold in hours below which files are kept.
    """
    if filter_list is None:
        filter_list = []
    # Basic cleanup logic
    if os.path.exists("/tmp"):
        for f in glob.glob("/tmp/*wav"):
            # Only delete if very old or explicitly temp (not implemented yet).
            pass
164
+
165
+
166
+ # --- MAIN DEMO FUNCTION ---
167
+
168
+
169
def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str] = None):
    """Gradio callback: mix optional noise into the input speech, run the
    enhancement model, and return paths + spectrogram images for both signals.

    :param speech_upl: filepath of the uploaded speech clip (may be None).
    :param noise_type: key into the module-level NOISES mapping.
    :param snr: target signal-to-noise ratio in dB (arrives as a string from
        the dropdown; cast to int below).
    :param mic_input: filepath of a microphone recording; takes precedence
        over the file upload when present.
    :return: (noisy_wav_path, noisy_spectrogram, enhanced_wav_path,
        enhanced_spectrogram).
    """
    # Microphone recording wins over the uploaded file.
    if mic_input:
        speech_upl = mic_input

    sr = TARGET_SR
    logger.info(f"Params: speech={speech_upl}, noise={noise_type}, snr={snr}")
    snr = int(snr)
    noise_fn = NOISES[noise_type]

    # 1. Load Clean Speech (cropped to at most `max_s` seconds)
    max_s = 10
    if speech_upl is not None:
        sample, _ = load_audio_torch(speech_upl, sr)
        max_len = max_s * sr
        if sample.shape[-1] > max_len:
            # Random crop so repeated runs showcase different segments.
            start = torch.randint(0, sample.shape[-1] - max_len, ()).item()
            sample = sample[..., start : start + max_len]
    else:
        # Fallback sample
        sample, _ = load_audio_torch("samples/p232_013_clean.wav", sr)
        sample = sample[..., : max_s * sr]

    # Ensure channels first (downmix multi-channel input to mono)
    if sample.dim() > 1 and sample.shape[0] > 1:
        sample = sample.mean(dim=0, keepdim=True)

    # 2. Add Noise (if selected)
    if noise_fn is not None:
        noise, _ = load_audio_torch(noise_fn, sr)
        _, _, sample = mix_at_snr(sample, noise, snr)

    # 3. Save Noisy File (Input for enhance_audio)
    # delete=False so the file survives for gradio to serve; cleanup is
    # left to cleanup_tmp / the OS.
    noisy_wav_path = tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False).name
    save_audio_torch(noisy_wav_path, sample, sr)

    # 4. Run Inference using your Custom Function
    enhanced_wav_path = tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False).name

    logger.info("Starting enhancement...")
    # CALLING YOUR MODEL HERE
    # NOTE(review): CHECKPOINT_PATH is "" at module level — presumably set
    # before deployment; confirm enhance_audio handles it.
    enhance_audio(CHECKPOINT_PATH, noisy_wav_path, enhanced_wav_path)
    logger.info("Enhancement finished")

    # 5. Load Enhanced Audio for Visualization
    enhanced, _ = load_audio_torch(enhanced_wav_path, sr)

    # 6. Generate Visuals (module-level figures are cleared and reused)
    ax_noisy.clear()
    ax_enh.clear()
    noisy_im = spec_im(sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
    enh_im = spec_im(enhanced, sr=sr, figure=fig_enh, ax=ax_enh)

    return noisy_wav_path, noisy_im, enhanced_wav_path, enh_im
222
+
223
+
224
def toggle(choice):
    """Swap visibility between the microphone and file-upload widgets.

    Returns (mic_update, file_update); the deselected widget is hidden and
    both values are reset to None.
    """
    use_mic = choice == "mic"
    mic_update = gr.update(visible=use_mic, value=None)
    file_update = gr.update(visible=not use_mic, value=None)
    return mic_update, file_update
229
+
230
+
231
+ # --- GRADIO INTERFACE ---
232
+
233
# Build the demo UI: source selector, noise/SNR controls, and paired
# audio + spectrogram outputs for the noisy input and enhanced result.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
            """
            ## Audio Enhancement Demo (Custom Model)
            Upload audio or record from mic to test the model in `mva-proj`.
            """
        )
    with gr.Row():
        with gr.Column():
            radio = gr.Radio(["mic", "file"], value="file", label="Audio Source")
            # NOTE(review): gr.Mic exists only in older gradio releases
            # (newer ones use gr.Audio(sources=["microphone"])) — confirm
            # against the pinned sdk_version.
            mic_input = gr.Mic(label="Microphone Input", type="filepath", visible=False)
            audio_file = gr.Audio(type="filepath", label="File Input", visible=True)
            # Order must match demo_fn(speech_upl, noise_type, snr, mic_input).
            inputs = [
                audio_file,
                gr.Dropdown(
                    label="Add background noise",
                    choices=list(NOISES.keys()),
                    value="None",
                ),
                gr.Dropdown(
                    label="Noise Level (SNR)",
                    choices=["-5", "0", "10", "20"],
                    value="10",
                ),
                mic_input,
            ]
            btn = gr.Button("Denoise", variant="primary")
        with gr.Column():
            # Order must match demo_fn's 4-tuple return value.
            outputs = [
                gr.Audio(type="filepath", label="Noisy Input"),
                gr.Image(label="Noisy Spectrogram"),
                gr.Audio(type="filepath", label="Enhanced Output"),
                gr.Image(label="Enhanced Spectrogram"),
            ]

    btn.click(fn=demo_fn, inputs=inputs, outputs=outputs, api_name="denoise")
    radio.change(toggle, radio, [mic_input, audio_file])

    # Examples (Ensure these files exist in your folder)
    if os.path.exists("samples/p232_013_clean.wav"):
        gr.Examples(
            [
                ["samples/p232_013_clean.wav", "Kitchen", "10"],
                ["samples/p232_013_clean.wav", "Cafe", "10"],
            ],
            fn=demo_fn,
            inputs=inputs,
            outputs=outputs,
            cache_examples=False,  # Disable cache if model changes frequently
        )

# NOTE(review): launch(enable_queue=...) was deprecated/removed in newer
# gradio (use demo.queue().launch()) — confirm against the pinned version.
demo.launch(enable_queue=True)
fspen/.gitignore ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ ./data/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ pip-wheel-metadata/
25
+ share/python-wheels/
26
+ *.egg-info/
27
+ .installed.cfg
28
+ *.egg
29
+ MANIFEST
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+ .coverage
46
+ .coverage.*
47
+ .cache
48
+ nosetests.xml
49
+ coverage.xml
50
+ *.cover
51
+ *.py,cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+ db.sqlite3-journal
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ .python-version
87
+
88
+ # pipenv
89
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
91
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
92
+ # install all needed dependencies.
93
+ #Pipfile.lock
94
+
95
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96
+ __pypackages__/
97
+
98
+ # Celery stuff
99
+ celerybeat-schedule
100
+ celerybeat.pid
101
+
102
+ # SageMath parsed files
103
+ *.sage.py
104
+
105
+ # Environments
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ ### VisualStudioCode
132
+ .vscode/*
133
+ !.vscode/settings.json
134
+ !.vscode/tasks.json
135
+ !.vscode/launch.json
136
+ !.vscode/extensions.json
137
+ *.code-workspace
138
+ **/.vscode
139
+
140
+ # JetBrains
141
+ .idea/
142
+
143
+ # Data & Models
144
+ *.h5
145
+ *.tar
146
+ *.tar.gz
147
+
148
+ # Lightning-Hydra-Template
149
+ configs/local/default.yaml
150
+ /data/
151
+ /logs/
152
+ .env
153
+
154
+ # Aim logging
155
+ .aim
156
+ Fspen_an_Ultra-Lightweight_Network_for_Real_Time_Speech_Enahncment.pdf
157
+ voicebank_data
158
+ voicebank_wavs
fspen/.pre-commit-config.yaml ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ default_language_version:
2
+ python: python3
3
+
4
+ repos:
5
+ - repo: https://github.com/pre-commit/pre-commit-hooks
6
+ rev: v6.0.0
7
+ hooks:
8
+ # list of supported hooks: https://pre-commit.com/hooks.html
9
+ - id: trailing-whitespace
10
+ - id: end-of-file-fixer
11
+ - id: check-docstring-first
12
+ - id: check-yaml
13
+ - id: debug-statements
14
+ - id: detect-private-key
15
+ - id: check-executables-have-shebangs
16
+ - id: check-toml
17
+ - id: check-case-conflict
18
+ - id: check-added-large-files
19
+
20
+ # python code formatting
21
+ - repo: https://github.com/psf/black
22
+ rev: 25.12.0
23
+ hooks:
24
+ - id: black
25
+ args: [--line-length, "99"]
26
+
27
+ # python import sorting
28
+ - repo: https://github.com/PyCQA/isort
29
+ rev: 7.0.0
30
+ hooks:
31
+ - id: isort
32
+ args: ["--profile", "black", "--filter-files"]
33
+
34
+ # python upgrading syntax to newer version
35
+ - repo: https://github.com/asottile/pyupgrade
36
+ rev: v3.21.2
37
+ hooks:
38
+ - id: pyupgrade
39
+ args: [--py38-plus]
40
+
41
+ # python docstring formatting
42
+ - repo: https://github.com/myint/docformatter
43
+ rev: v1.7.7
44
+ hooks:
45
+ - id: docformatter
46
+ args:
47
+ [
48
+ --in-place,
49
+ --wrap-summaries=99,
50
+ --wrap-descriptions=99,
51
+ --style=sphinx,
52
+ --black,
53
+ ]
54
+
55
+ # # python docstring coverage checking
56
+ # - repo: https://github.com/econchick/interrogate
57
+ # rev: 1.7.0 # or master if you're bold
58
+ # hooks:
59
+ # - id: interrogate
60
+ # args:
61
+ # [
62
+ # --verbose,
63
+ # --fail-under=80,
64
+ # --ignore-init-module,
65
+ # --ignore-init-method,
66
+ # --ignore-module,
67
+ # --ignore-nested-functions,
68
+ # -vv,
69
+ # ]
70
+
71
+ # python check (PEP8), programming errors and code complexity
72
+ - repo: https://github.com/PyCQA/flake8
73
+ rev: 7.3.0
74
+ hooks:
75
+ - id: flake8
76
+ args:
77
+ [
78
+ "--extend-ignore",
79
+ "E203,E402,E501,F401,F841,RST2,RST301",
80
+ "--exclude",
81
+ "logs/*,data/*",
82
+ ]
83
+ additional_dependencies: [flake8-rst-docstrings==0.3.0]
84
+
85
+ # python security linter
86
+ - repo: https://github.com/PyCQA/bandit
87
+ rev: "1.9.2"
88
+ hooks:
89
+ - id: bandit
90
+ args: ["-s", "B101,B311"]
91
+
92
+ # yaml formatting
93
+ - repo: https://github.com/pre-commit/mirrors-prettier
94
+ rev: v4.0.0-alpha.8
95
+ hooks:
96
+ - id: prettier
97
+ types: [yaml]
98
+ exclude: "environment.yaml"
99
+
100
+ # shell scripts linter
101
+ - repo: https://github.com/shellcheck-py/shellcheck-py
102
+ rev: v0.11.0.1
103
+ hooks:
104
+ - id: shellcheck
105
+
106
+ # word spelling linter
107
+ - repo: https://github.com/codespell-project/codespell
108
+ rev: v2.4.1
109
+ hooks:
110
+ - id: codespell
111
+ args:
112
+ - --skip=logs/**,data/**,*.ipynb
113
+ # - --ignore-words-list=abc,def
114
+
115
+ # jupyter notebook cell output clearing
116
+ - repo: https://github.com/kynan/nbstripout
117
+ rev: 0.8.2
118
+ hooks:
119
+ - id: nbstripout
120
+
121
+ # jupyter notebook linting
122
+ - repo: https://github.com/nbQA-dev/nbQA
123
+ rev: 1.9.1
124
+ hooks:
125
+ - id: nbqa-black
126
+ args: ["--line-length=99"]
127
+ - id: nbqa-isort
128
+ args: ["--profile=black"]
129
+ - id: nbqa-flake8
130
+ args:
131
+ [
132
+ "--extend-ignore=E203,E402,E501,F401,F841",
133
+ "--exclude=logs/*,data/*",
134
+ ]
fspen/.project-root ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # this file is required for inferring the project root directory
2
+ # do not delete
fspen/Makefile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ help: ## Show help
3
+ @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
4
+
5
+ clean: ## Clean autogenerated files
6
+ rm -rf dist
7
+ find . -type f -name "*.DS_Store" -ls -delete
8
+ find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
9
+ find . | grep -E ".pytest_cache" | xargs rm -rf
10
+ find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
11
+ rm -f .coverage
12
+
13
+ clean-logs: ## Clean logs
14
+ rm -rf logs/**
15
+
16
+ format: ## Run pre-commit hooks
17
+ pre-commit run -a
18
+
19
+ sync: ## Merge changes from main branch to your current branch
20
+ git pull
21
+ git pull origin main
22
+
23
+ test: ## Run not slow tests
24
+ pytest -k "not slow"
25
+
26
+ test-full: ## Run all tests
27
+ pytest
28
+
29
+ train: ## Train the model
30
+ python src/train.py
fspen/README.md ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ # Your Project Name
4
+
5
+ <a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
6
+ <a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a>
7
+ <a href="https://hydra.cc/"><img alt="Config: Hydra" src="https://img.shields.io/badge/Config-Hydra-89b8cd"></a>
8
+ <a href="https://github.com/ashleve/lightning-hydra-template"><img alt="Template" src="https://img.shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=gray"></a><br>
9
+ [![Paper](http://img.shields.io/badge/paper-arxiv.1001.2234-B31B1B.svg)](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=10446016)
10
+
11
+ </div>
12
+
13
+ ## Description
14
+
15
+ What it does
16
+
17
+ ## Installation
18
+
19
+ #### Pip
20
+
21
+ ```bash
22
+ # clone project
23
+ git clone https://github.com/iliasslasri/<repo-name>
24
+ cd <repo-name>
25
+
26
+ # [OPTIONAL] create conda environment
27
+ conda create -n myenv python=3.9
28
+ conda activate myenv
29
+
30
+ # install pytorch according to instructions
31
+ # https://pytorch.org/get-started/
32
+
33
+ # install requirements
34
+ pip install -r requirements.txt
35
+ ```
36
+
37
+ #### Conda
38
+
39
+ ```bash
40
+ # clone project
41
+ git clone https://github.com/YourGithubName/your-repo-name
42
+ cd your-repo-name
43
+
44
+ # create conda environment and install dependencies
45
+ conda env create -f environment.yaml -n myenv
46
+
47
+ # activate conda environment
48
+ conda activate myenv
49
+ ```
50
+
51
+ ## How to run
52
+
53
+ Train model with default configuration
54
+
55
+ ```bash
56
+ # train on CPU
57
+ python src/train.py trainer=cpu
58
+
59
+ # train on GPU
60
+ python src/train.py trainer=gpu
61
+ ```
62
+
63
+ Train model with chosen experiment configuration from [configs/experiment/](configs/experiment/)
64
+
65
+ ```bash
66
+ python src/train.py experiment=experiment_name.yaml
67
+ ```
68
+
69
+ You can override any parameter from command line like this
70
+
71
+ ```bash
72
+ python src/train.py trainer.max_epochs=20 data.batch_size=64
73
+ ```
74
+
75
+
76
+
77
+
78
+ ## Set up environment for dev
79
+ ```bash
80
+ pre-commit install
81
+
82
+ pip install \
83
+ "torch==2.0.1+cu118" \
84
+ "torchvision==0.15.2+cu118" \
85
+ "torchaudio==2.0.2+cu118" \
86
+ "lightning==2.0.9" \
87
+ "torchmetrics==0.11.4" \
88
+ "numpy<2.0" \
89
+ "pesq" \
90
+ "hydra-colorlog" \
91
+ --extra-index-url https://download.pytorch.org/whl/cu118
92
+
93
+ python3 src/train.py callbacks.rich_progress_bar=null
94
+ ```
fspen/configs/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # this file is needed here to include configs when building project as a package
fspen/configs/callbacks/default.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
defaults:
  - model_checkpoint
  - early_stopping
  - model_summary
  - rich_progress_bar
  - _self_

model_checkpoint:
  dirpath: ${paths.output_dir}/checkpoints
  filename: "epoch_{epoch:03d}"
  monitor: "val/loss"
  mode: "min"
  save_last: True
  auto_insert_metric_name: False
  save_top_k: 3

early_stopping:
  monitor: "val/loss"
  patience: 100
  # BUG FIX: was "max" — a loss must be minimized; with "max" a rising
  # val/loss counts as improvement and early stopping never triggers.
  # This also matches model_checkpoint above, which uses mode "min".
  mode: "min"

model_summary:
  max_depth: -1
fspen/configs/callbacks/early_stopping.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
2
+
3
+ early_stopping:
4
+ _target_: lightning.pytorch.callbacks.EarlyStopping
5
+ monitor: ??? # quantity to be monitored, must be specified !!!
6
+ min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
7
+ patience: 3 # number of checks with no improvement after which training will be stopped
8
+ verbose: False # verbosity mode
9
+ mode: "min" # "max" means higher metric value is better, can be also "min"
10
+ strict: True # whether to crash the training if monitor is not found in the validation metrics
11
+ check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
12
+ stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
13
+ divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
14
+ check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
15
+ # log_rank_zero_only: False # this keyword argument isn't available in stable version
fspen/configs/callbacks/model_checkpoint.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
2
+
3
+ model_checkpoint:
4
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
5
+ dirpath: null # directory to save the model file
6
+ filename: null # checkpoint filename
7
+ monitor: null # name of the logged metric which determines when model is improving
8
+ verbose: False # verbosity mode
9
+ save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
10
+ save_top_k: 1 # save k best models (determined by above metric)
11
+ mode: "min" # "max" means higher metric value is better, can be also "min"
12
+ auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
13
+ save_weights_only: False # if True, then only the model’s weights will be saved
14
+ every_n_train_steps: null # number of training steps between checkpoints
15
+ train_time_interval: null # checkpoints are monitored at the specified time interval
16
+ every_n_epochs: null # number of epochs between checkpoints
17
+ save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
fspen/configs/callbacks/model_summary.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
2
+
3
+ model_summary:
4
+ _target_: lightning.pytorch.callbacks.RichModelSummary
5
+ max_depth: 1 # the maximum depth of layer nesting that the summary will include
fspen/configs/callbacks/none.yaml ADDED
File without changes
fspen/configs/callbacks/rich_progress_bar.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
2
+
3
+ rich_progress_bar:
4
+ _target_: lightning.pytorch.callbacks.RichProgressBar
fspen/configs/data/speech_enhancement.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.data.datamodule.SpeechEnhancementDataModule
2
+ dataset:
3
+ _target_: src.data.dataset.SpeechEnhancementDataset
4
+ sample_rate: 16000
5
+ segment_len: 10.0 # in seconds
6
+ n_fft: 512
7
+ hop_length: 128
8
+ win_length: 512
9
+ noisy_dir: ${paths.noisy_dir}
10
+ clean_dir: ${paths.clean_dir}
11
+ batch_size: 64 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
12
+ val_split: 0.1
13
+ num_workers: 2
fspen/configs/debug/default.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # default debugging setup, runs 1 full epoch
4
+ # other debugging configs can inherit from this one
5
+
6
+ # overwrite task name so debugging logs are stored in separate folder
7
+ task_name: "debug"
8
+
9
+ # disable callbacks and loggers during debugging
10
+ callbacks: null
11
+ logger: null
12
+
13
+ extras:
14
+ ignore_warnings: False
15
+ enforce_tags: False
16
+
17
+ # sets level of all command line loggers to 'DEBUG'
18
+ # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
19
+ hydra:
20
+ job_logging:
21
+ root:
22
+ level: DEBUG
23
+
24
+ # use this to also set hydra loggers to 'DEBUG'
25
+ # verbose: True
26
+
27
+ trainer:
28
+ max_epochs: 1
29
+ accelerator: cpu # debuggers don't like gpus
30
+ devices: 1 # debuggers don't like multiprocessing
31
+ detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
32
+
33
+ data:
34
+ num_workers: 0 # debuggers don't like multiprocessing
35
+ pin_memory: False # disable gpu memory pin
fspen/configs/debug/fdr.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # runs 1 train, 1 validation and 1 test step
4
+
5
+ defaults:
6
+ - default
7
+
8
+ trainer:
9
+ fast_dev_run: true
fspen/configs/debug/limit.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # uses only 1% of the training data and 5% of validation/test data
4
+
5
+ defaults:
6
+ - default
7
+
8
+ trainer:
9
+ max_epochs: 3
10
+ limit_train_batches: 0.01
11
+ limit_val_batches: 0.05
12
+ limit_test_batches: 0.05
fspen/configs/debug/overfit.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # overfits to 3 batches
4
+
5
+ defaults:
6
+ - default
7
+
8
+ trainer:
9
+ max_epochs: 20
10
+ overfit_batches: 3
11
+
12
+ # model ckpt and early stopping need to be disabled during overfitting
13
+ callbacks: null
fspen/configs/debug/profiler.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # runs with execution time profiling
4
+
5
+ defaults:
6
+ - default
7
+
8
+ trainer:
9
+ max_epochs: 1
10
+ profiler: "simple"
11
+ # profiler: "advanced"
12
+ # profiler: "pytorch"
fspen/configs/eval.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - _self_
5
+ - data: speech_enhancement
6
+ - model: fspen
7
+ - callbacks: default
8
+ - logger: tensorboard
9
+ - trainer: gpu
10
+ - paths: eval
11
+ - extras: default
12
+ - hydra: default
13
+
14
+ task_name: "eval"
15
+
16
+ tags: ["dev"]
17
+
18
+ # passing checkpoint path is necessary for evaluation
19
+ ckpt_path: ???
fspen/configs/experiment/example.yaml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=example
5
+
6
+ defaults:
7
+ - override /data: mnist
8
+ - override /model: mnist
9
+ - override /callbacks: default
10
+ - override /trainer: default
11
+
12
+ # all parameters below will be merged with parameters from default configurations set above
13
+ # this allows you to overwrite only specified parameters
14
+
15
+ tags: ["mnist", "simple_dense_net"]
16
+
17
+ seed: 12345
18
+
19
+ trainer:
20
+ min_epochs: 10
21
+ max_epochs: 10
22
+ gradient_clip_val: 0.5
23
+
24
+ model:
25
+ optimizer:
26
+ lr: 0.002
27
+ net:
28
+ lin1_size: 128
29
+ lin2_size: 256
30
+ lin3_size: 64
31
+ compile: false
32
+
33
+ data:
34
+ batch_size: 64
35
+
36
+ logger:
37
+ wandb:
38
+ tags: ${tags}
39
+ group: "mnist"
40
+ aim:
41
+ experiment: "mnist"
fspen/configs/extras/default.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # disable python warnings if they annoy you
2
+ ignore_warnings: False
3
+
4
+ # ask user for tags if none are provided in the config
5
+ enforce_tags: True
6
+
7
+ # pretty print config tree at the start of the run using Rich library
8
+ print_config: True
fspen/configs/hparams_search/mnist_optuna.yaml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # example hyperparameter optimization of some experiment with Optuna:
4
+ # python train.py -m hparams_search=mnist_optuna experiment=example
5
+
6
+ defaults:
7
+ - override /hydra/sweeper: optuna
8
+
9
+ # choose metric which will be optimized by Optuna
10
+ # make sure this is the correct name of some metric logged in lightning module!
11
+ optimized_metric: "val/acc_best"
12
+
13
+ # here we define Optuna hyperparameter search
14
+ # it optimizes for value returned from function with @hydra.main decorator
15
+ # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
16
+ hydra:
17
+ mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
18
+
19
+ sweeper:
20
+ _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
21
+
22
+ # storage URL to persist optimization results
23
+ # for example, you can use SQLite if you set 'sqlite:///example.db'
24
+ storage: null
25
+
26
+ # name of the study to persist optimization results
27
+ study_name: null
28
+
29
+ # number of parallel workers
30
+ n_jobs: 1
31
+
32
+ # 'minimize' or 'maximize' the objective
33
+ direction: maximize
34
+
35
+ # total number of runs that will be executed
36
+ n_trials: 20
37
+
38
+ # choose Optuna hyperparameter sampler
39
+ # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
40
+ # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
41
+ sampler:
42
+ _target_: optuna.samplers.TPESampler
43
+ seed: 1234
44
+ n_startup_trials: 10 # number of random sampling runs before optimization starts
45
+
46
+ # define hyperparameter search space
47
+ params:
48
+ model.optimizer.lr: interval(0.0001, 0.1)
49
+ data.batch_size: choice(32, 64, 128, 256)
50
+ model.net.lin1_size: choice(64, 128, 256)
51
+ model.net.lin2_size: choice(64, 128, 256)
52
+ model.net.lin3_size: choice(32, 64, 128, 256)
fspen/configs/hydra/default.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://hydra.cc/docs/configure_hydra/intro/
2
+
3
+ # enable color logging
4
+ defaults:
5
+ - override hydra_logging: colorlog
6
+ - override job_logging: colorlog
7
+
8
+ # output directory, generated dynamically on each run
9
+ run:
10
+ dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
11
+ sweep:
12
+ dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
13
+ subdir: ${hydra.job.num}
14
+
15
+ job_logging:
16
+ handlers:
17
+ file:
18
+ # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
19
+ filename: ${hydra.runtime.output_dir}/${task_name}.log
fspen/configs/local/.gitkeep ADDED
File without changes
fspen/configs/logger/aim.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://aimstack.io/
2
+
3
+ # example usage in lightning module:
4
+ # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
5
+
6
+ # open the Aim UI with the following command (run in the folder containing the `.aim` folder):
7
+ # `aim up`
8
+
9
+ aim:
10
+ _target_: aim.pytorch_lightning.AimLogger
11
+ repo: ${paths.root_dir} # .aim folder will be created here
12
+ # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
13
+
14
+ # aim allows to group runs under experiment name
15
+ experiment: null # any string, set to "default" if not specified
16
+
17
+ train_metric_prefix: "train/"
18
+ val_metric_prefix: "val/"
19
+ test_metric_prefix: "test/"
20
+
21
+ # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
22
+ system_tracking_interval: 10 # set to null to disable system metrics tracking
23
+
24
+ # enable/disable logging of system params such as installed packages, git info, env vars, etc.
25
+ log_system_params: true
26
+
27
+ # enable/disable tracking console logs (default value is true)
28
+ capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550
fspen/configs/logger/comet.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://www.comet.ml
2
+
3
+ comet:
4
+ _target_: lightning.pytorch.loggers.comet.CometLogger
5
+ api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
6
+ save_dir: "${paths.output_dir}"
7
+ project_name: "lightning-hydra-template"
8
+ rest_api_key: null
9
+ # experiment_name: ""
10
+ experiment_key: null # set to resume experiment
11
+ offline: False
12
+ prefix: ""
fspen/configs/logger/csv.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # csv logger built in lightning
2
+
3
+ csv:
4
+ _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
5
+ save_dir: "${paths.output_dir}"
6
+ name: "csv/"
7
+ prefix: ""
fspen/configs/logger/many_loggers.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # train with many loggers at once
2
+
3
+ defaults:
4
+ # - comet
5
+ - csv
6
+ # - mlflow
7
+ # - neptune
8
+ - tensorboard
9
+ - wandb
fspen/configs/logger/mlflow.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://mlflow.org
2
+
3
+ mlflow:
4
+ _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
5
+ # experiment_name: ""
6
+ # run_name: ""
7
+ tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
8
+ tags: null
9
+ # save_dir: "./mlruns"
10
+ prefix: ""
11
+ artifact_location: null
12
+ # run_id: ""
fspen/configs/logger/neptune.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # https://neptune.ai
2
+
3
+ neptune:
4
+ _target_: lightning.pytorch.loggers.neptune.NeptuneLogger
5
+ api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
6
+ project: username/lightning-hydra-template
7
+ # name: ""
8
+ log_model_checkpoints: True
9
+ prefix: ""
fspen/configs/logger/tensorboard.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://www.tensorflow.org/tensorboard/
2
+
3
+ tensorboard:
4
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
5
+ save_dir: "${paths.output_dir}/tensorboard/"
6
+ name: null
7
+ log_graph: False
8
+ default_hp_metric: True
9
+ prefix: ""
10
+ # version: ""
fspen/configs/logger/wandb.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://wandb.ai
2
+
3
+ wandb:
4
+ _target_: lightning.pytorch.loggers.wandb.WandbLogger
5
+ # name: "" # name of the run (normally generated by wandb)
6
+ save_dir: "${paths.output_dir}"
7
+ offline: False
8
+ id: null # pass correct id to resume experiment!
9
+ anonymous: null # enable anonymous logging
10
+ project: "lightning-hydra-template"
11
+ log_model: False # upload lightning ckpts
12
+ prefix: "" # a string to put at the beginning of metric keys
13
+ # entity: "" # set to name of your wandb team
14
+ group: ""
15
+ tags: []
16
+ job_type: ""
fspen/configs/model/fspen.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.models.fspen_module.FSPENLitModule
2
+
3
+ optimizer:
4
+ _target_: torch.optim.Adam
5
+ _partial_: true
6
+ lr: 0.001
7
+ weight_decay: 0.0
8
+
9
+ scheduler:
10
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
11
+ _partial_: true
12
+ mode: min
13
+ factor: 0.1
14
+ patience: 10
15
+
16
+ net:
17
+ _target_: src.models.fspen.FSPEN
18
+
19
+ # compile model for faster training with pytorch 2.0
20
+ compile: false
21
+
22
+ criterion:
23
+ _target_: src.models.components.loss.MultiResolutionSTFTLoss
24
+ fft_sizes: [512, 1024, 2048]
fspen/configs/paths/default.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # path to root directory
2
+ # this requires PROJECT_ROOT environment variable to exist
3
+ # you can replace it with "." if you want the root to be the current working directory
4
+ root_dir: ${oc.env:PROJECT_ROOT}
5
+
6
+ # path to data directory
7
+ noisy_dir: ${paths.root_dir}/data/noisy/
8
+ clean_dir: ${paths.root_dir}/data/clean/
9
+
10
+ # path to logging directory
11
+ log_dir: ${paths.root_dir}/logs/
12
+
13
+ # path to output directory, created dynamically by hydra
14
+ # path generation pattern is specified in `configs/hydra/default.yaml`
15
+ # use it to store all files generated during the run, like ckpts and metrics
16
+ output_dir: ${hydra:runtime.output_dir}
17
+
18
+ # path to working directory
19
+ work_dir: ${hydra:runtime.cwd}
fspen/configs/paths/eval.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # path to root directory
2
+ # this requires PROJECT_ROOT environment variable to exist
3
+ # you can replace it with "." if you want the root to be the current working directory
4
+ root_dir: ${oc.env:PROJECT_ROOT}
5
+
6
+ # path to data directory
7
+ noisy_dir: ${paths.root_dir}/voicebank_wavs/test/noisy/
8
+ clean_dir: ${paths.root_dir}/voicebank_wavs/test/clean/
9
+
10
+ # path to logging directory
11
+ log_dir: ${paths.root_dir}/logs/
12
+
13
+ # path to output directory, created dynamically by hydra
14
+ # path generation pattern is specified in `configs/hydra/default.yaml`
15
+ # use it to store all files generated during the run, like ckpts and metrics
16
+ output_dir: ${hydra:runtime.output_dir}
17
+
18
+ # path to working directory
19
+ work_dir: ${hydra:runtime.cwd}
fspen/configs/train.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # specify here default configuration
4
+ # order of defaults determines the order in which configs override each other
5
+ defaults:
6
+ - _self_
7
+ - data: speech_enhancement
8
+ - model: fspen
9
+ - callbacks: default
10
+ - logger: tensorboard
11
+ - trainer: gpu
12
+ - paths: default
13
+ - extras: default
14
+ - hydra: default
15
+
16
+ # experiment configs allow for version control of specific hyperparameters
17
+ # e.g. best hyperparameters for given model and datamodule
18
+ - experiment: null
19
+
20
+ # config for hyperparameter optimization
21
+ - hparams_search: null
22
+
23
+ # optional local config for machine/user specific settings
24
+ # it's optional since it doesn't need to exist and is excluded from version control
25
+ - optional local: default
26
+
27
+ # debugging config (enable through command line, e.g. `python train.py debug=default`)
28
+ - debug: null
29
+
30
+ # task name, determines output directory path
31
+ task_name: "train"
32
+
33
+ # tags to help you identify your experiments
34
+ # you can overwrite this in experiment configs
35
+ # overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
36
+ tags: ["dev"]
37
+
38
+ # set False to skip model training
39
+ train: True
40
+
41
+ # evaluate on test set, using best model weights achieved during training
42
+ # lightning chooses best weights based on the metric specified in checkpoint callback
43
+ test: False
44
+
45
+ # simply provide checkpoint path to resume training
46
+ ckpt_path: null
47
+
48
+ # seed for random number generators in pytorch, numpy and python.random
49
+ seed: null
fspen/configs/trainer/cpu.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ defaults:
2
+ - default
3
+
4
+ accelerator: cpu
5
+ devices: 1
fspen/configs/trainer/ddp.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - default
3
+
4
+ strategy: ddp
5
+
6
+ accelerator: gpu
7
+ devices: 4
8
+ num_nodes: 1
9
+ sync_batchnorm: True
fspen/configs/trainer/ddp_sim.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - default
3
+
4
+ # simulate DDP on CPU, useful for debugging
5
+ accelerator: cpu
6
+ devices: 2
7
+ strategy: ddp_spawn
fspen/configs/trainer/default.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: lightning.pytorch.trainer.Trainer
2
+
3
+ default_root_dir: ${paths.output_dir}
4
+
5
+ min_epochs: 1 # prevents early stopping
6
+ max_epochs: 100
7
+
8
+ accelerator: cpu
9
+ devices: 1
10
+
11
+ # mixed precision for extra speed-up
12
+ # precision: 16
13
+
14
+ # perform a validation loop every N training epochs
15
+ check_val_every_n_epoch: 1
16
+
17
+ # set True to ensure deterministic results
18
+ # makes training slower but gives more reproducibility than just setting seeds
19
+ deterministic: False
fspen/configs/trainer/gpu.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ defaults:
2
+ - default
3
+
4
+ accelerator: gpu
5
+ devices: 1
fspen/configs/trainer/mps.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ defaults:
2
+ - default
3
+
4
+ accelerator: mps
5
+ devices: 1
fspen/environment.yaml ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##############
2
+ ## This environment uses old versions of DL libraries because training is done on an old GPU
3
+ ## mainly torch==2.0.1+cu118
4
+ ##############
5
+ name: fspen
6
+ channels:
7
+ - defaults
8
+ - https://repo.anaconda.com/pkgs/main
9
+ - https://repo.anaconda.com/pkgs/r
10
+ dependencies:
11
+ - _libgcc_mutex=0.1
12
+ - _openmp_mutex=5.1
13
+ - bzip2=1.0.8
14
+ - ca-certificates=2025.9.9
15
+ - expat=2.7.1
16
+ - ld_impl_linux-64=2.44
17
+ - libffi=3.4.4
18
+ - libgcc-ng=11.2.0
19
+ - libgomp=11.2.0
20
+ - libnsl=2.0.0
21
+ - libstdcxx-ng=11.2.0
22
+ - libuuid=1.41.5
23
+ - libxcb=1.17.0
24
+ - libzlib=1.3.1
25
+ - ncurses=6.5
26
+ - openssl=3.0.18
27
+ - pip=25.2
28
+ - pthread-stubs=0.3
29
+ - python=3.10.19
30
+ - readline=8.3
31
+ - setuptools=80.9.0
32
+ - sqlite=3.50.2
33
+ - tk=8.6.15
34
+ - wheel=0.45.1
35
+ - xorg-libx11=1.8.12
36
+ - xorg-libxau=1.0.12
37
+ - xorg-libxdmcp=1.1.5
38
+ - xorg-xorgproto=2024.1
39
+ - xz=5.6.4
40
+ - zlib=1.3.1
41
+ - pip:
42
+ - absl-py==2.3.1
43
+ - aiohappyeyeballs==2.6.1
44
+ - aiohttp==3.13.2
45
+ - aiosignal==1.4.0
46
+ - antlr4-python3-runtime==4.9.3
47
+ - async-timeout==5.0.1
48
+ - certifi==2025.10.5
49
+ - charset-normalizer==3.4.4
50
+ - cmake==3.25.0
51
+ - contourpy==1.3.2
52
+ - cycler==0.12.1
53
+ - decorator==5.2.1
54
+ - deprecate==1.0.5
55
+ - dotenv==0.9.9
56
+ - einops==0.8.1
57
+ - filelock==3.19.1
58
+ - fonttools==4.60.1
59
+ - frozenlist==1.8.0
60
+ - fsspec==2025.9.0
61
+ - grpcio==1.76.0
62
+ - h5py==3.15.1
63
+ - hf-xet==1.1.10
64
+ - huggingface-hub==0.35.3
65
+ - hydra-colorlog==1.2.0
66
+ - hydra-core==1.3.2
67
+ - hydra-optuna-sweeper==1.2.0
68
+ - idna==3.11
69
+ - indic-numtowords==1.1.0
70
+ - iniconfig==2.3.0
71
+ - ipdb==0.13.13
72
+ - jinja2==3.1.6
73
+ - kiwisolver==1.4.9
74
+ - lightning==2.3.0
75
+ - lightning-utilities==0.15.2
76
+ - lit==15.0.7
77
+ - markdown==3.9
78
+ - markupsafe==2.1.5
79
+ - matplotlib==3.10.7
80
+ - more-itertools==10.8.0
81
+ - mpmath==1.3.0
82
+ - multidict==6.7.0
83
+ - numpy==1.26.4
84
+ - omegaconf==2.3.0
85
+ - pandas==2.3.3
86
+ - pexpect==4.9.0
87
+ - pillow==12.0.0
88
+ - pluggy==1.6.0
89
+ - pre-commit==4.5.0
90
+ - propcache==0.4.1
91
+ - protobuf==6.33.0
92
+ - ptyprocess==0.7.0
93
+ - pygments==2.19.2
94
+ - pyparsing==3.2.5
95
+ - pytest==9.0.1
96
+ - python-dotenv==1.2.1
97
+ - pytz==2025.2
98
+ - pyyaml==6.0.3
99
+ - regex==2025.10.23
100
+ - rich==14.2.0
101
+ - rootutils==1.0.7
102
+ - safetensors==0.6.2
103
+ - scipy==1.15.3
104
+ - six==1.17.0
105
+ - soundfile==0.13.1
106
+ - sympy==1.14.0
107
+ - tensorboard==2.20.0
108
+ - tensorboard-data-server==0.7.2
109
+ - tokenizers==0.22.1
110
+ - torch==2.0.1+cu118
111
+ - torch-tb-profiler==0.4.3
112
+ - torchaudio==2.0.2+cu118
113
+ - torchinfo==1.8.0
114
+ - torchmetrics==0.11.4
115
+ - torchsummary==1.5.1
116
+ - tqdm==4.67.1
117
+ - transformers==4.57.1
118
+ - triton==2.0.0
119
+ - typing-extensions==4.15.0
120
+ - tzdata==2025.2
121
+ - urllib3==2.5.0
122
+ - werkzeug==3.1.3
123
+ - xxhash==3.6.0
124
+ - yarl==1.22.0
125
+ prefix: /home/infres/lasri-22/miniconda3/envs/fspen
fspen/notebooks/.gitkeep ADDED
File without changes
fspen/pyproject.toml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.pytest.ini_options]
2
+ addopts = [
3
+ "--color=yes",
4
+ "--durations=0",
5
+ "--strict-markers",
6
+ "--doctest-modules",
7
+ ]
8
+ filterwarnings = [
9
+ "ignore::DeprecationWarning",
10
+ "ignore::UserWarning",
11
+ ]
12
+ log_cli = "True"
13
+ markers = [
14
+ "slow: slow tests",
15
+ ]
16
+ minversion = "6.0"
17
+ testpaths = "tests/"
18
+
19
+ [tool.coverage.report]
20
+ exclude_lines = [
21
+ "pragma: nocover",
22
+ "raise NotImplementedError",
23
+ "raise NotImplementedError()",
24
+ "if __name__ == .__main__.:",
25
+ ]
fspen/requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------- pytorch --------- #
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ lightning>=2.0.0
5
+ torchmetrics>=0.11.4
6
+
7
+ # --------- hydra --------- #
8
+ hydra-core==1.3.2
9
+ hydra-colorlog==1.2.0
10
+ hydra-optuna-sweeper==1.2.0
11
+
12
+ # --------- loggers --------- #
13
+ # wandb
14
+ # neptune-client
15
+ # mlflow
16
+ # comet-ml
17
+ # aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550
18
+
19
+ # --------- others --------- #
20
+ rootutils # standardizing the project root setup
21
+ pre-commit # hooks for applying linters on commit
22
+ rich # beautiful text formatting in terminal
23
+ pytest # tests
24
+ # sh # for running bash commands in some tests (linux/macos only)
fspen/scripts/schedule.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Schedule execution of many runs
3
+ # Run from root folder with: bash scripts/schedule.sh
4
+
5
+ python src/train.py trainer.max_epochs=5 logger=csv
6
+
7
+ python src/train.py trainer.max_epochs=10 logger=csv
fspen/setup.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from setuptools import find_packages, setup
4
+
5
+ setup(
6
+ name="src",
7
+ version="0.0.1",
8
+ description="Describe Your Cool Project",
9
+ author="",
10
+ author_email="",
11
+ url="https://github.com/user/project",
12
+ install_requires=["lightning", "hydra-core"],
13
+ packages=find_packages(),
14
+ # use this to customize global commands available in the terminal after installing the package
15
+ entry_points={
16
+ "console_scripts": [
17
+ "train_command = src.train:main",
18
+ "eval_command = src.eval:main",
19
+ ]
20
+ },
21
+ )