Spaces:
Build error
Build error
Iliass Lasri commited on
Commit ·
27d7586
1
Parent(s): 54dc2f8
added all files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- README.md +7 -6
- app.py +285 -0
- fspen/.gitignore +158 -0
- fspen/.pre-commit-config.yaml +134 -0
- fspen/.project-root +2 -0
- fspen/Makefile +30 -0
- fspen/README.md +94 -0
- fspen/configs/__init__.py +1 -0
- fspen/configs/callbacks/default.yaml +23 -0
- fspen/configs/callbacks/early_stopping.yaml +15 -0
- fspen/configs/callbacks/model_checkpoint.yaml +17 -0
- fspen/configs/callbacks/model_summary.yaml +5 -0
- fspen/configs/callbacks/none.yaml +0 -0
- fspen/configs/callbacks/rich_progress_bar.yaml +4 -0
- fspen/configs/data/speech_enhancement.yaml +13 -0
- fspen/configs/debug/default.yaml +35 -0
- fspen/configs/debug/fdr.yaml +9 -0
- fspen/configs/debug/limit.yaml +12 -0
- fspen/configs/debug/overfit.yaml +13 -0
- fspen/configs/debug/profiler.yaml +12 -0
- fspen/configs/eval.yaml +19 -0
- fspen/configs/experiment/example.yaml +41 -0
- fspen/configs/extras/default.yaml +8 -0
- fspen/configs/hparams_search/mnist_optuna.yaml +52 -0
- fspen/configs/hydra/default.yaml +19 -0
- fspen/configs/local/.gitkeep +0 -0
- fspen/configs/logger/aim.yaml +28 -0
- fspen/configs/logger/comet.yaml +12 -0
- fspen/configs/logger/csv.yaml +7 -0
- fspen/configs/logger/many_loggers.yaml +9 -0
- fspen/configs/logger/mlflow.yaml +12 -0
- fspen/configs/logger/neptune.yaml +9 -0
- fspen/configs/logger/tensorboard.yaml +10 -0
- fspen/configs/logger/wandb.yaml +16 -0
- fspen/configs/model/fspen.yaml +24 -0
- fspen/configs/paths/default.yaml +19 -0
- fspen/configs/paths/eval.yaml +19 -0
- fspen/configs/train.yaml +49 -0
- fspen/configs/trainer/cpu.yaml +5 -0
- fspen/configs/trainer/ddp.yaml +9 -0
- fspen/configs/trainer/ddp_sim.yaml +7 -0
- fspen/configs/trainer/default.yaml +19 -0
- fspen/configs/trainer/gpu.yaml +5 -0
- fspen/configs/trainer/mps.yaml +5 -0
- fspen/environment.yaml +125 -0
- fspen/notebooks/.gitkeep +0 -0
- fspen/pyproject.toml +25 -0
- fspen/requirements.txt +24 -0
- fspen/scripts/schedule.sh +7 -0
- fspen/setup.py +21 -0
README.md
CHANGED
|
@@ -1,12 +1,13 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 6.2.0
|
| 8 |
app_file: app.py
|
|
|
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces
|
|
|
|
| 1 |
---
|
| 2 |
+
title: DeepFilterNet2
|
| 3 |
+
emoji: 💩
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: gradio
|
|
|
|
| 7 |
app_file: app.py
|
| 8 |
+
sdk_version: 3.17.1
|
| 9 |
pinned: false
|
| 10 |
+
license: apache-2.0
|
| 11 |
---
|
| 12 |
|
| 13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
|
app.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import tempfile
|
| 6 |
+
import time
|
| 7 |
+
from typing import List, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import gradio as gr
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torchaudio
|
| 14 |
+
import torchaudio.transforms as T
|
| 15 |
+
from loguru import logger
|
| 16 |
+
from PIL import Image
|
| 17 |
+
|
| 18 |
+
sys.path.append("fspen")
|
| 19 |
+
from fspen.src.test import enhance_audio
|
| 20 |
+
|
| 21 |
+
CHECKPOINT_PATH = ""
|
| 22 |
+
TARGET_SR = 16000
|
| 23 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 24 |
+
|
| 25 |
+
# --- PLOTTING SETUP ---
|
| 26 |
+
fig_noisy: plt.Figure
|
| 27 |
+
fig_enh: plt.Figure
|
| 28 |
+
ax_noisy: plt.Axes
|
| 29 |
+
ax_enh: plt.Axes
|
| 30 |
+
fig_noisy, ax_noisy = plt.subplots(figsize=(15.2, 4))
|
| 31 |
+
fig_noisy.set_tight_layout(True)
|
| 32 |
+
fig_enh, ax_enh = plt.subplots(figsize=(15.2, 4))
|
| 33 |
+
fig_enh.set_tight_layout(True)
|
| 34 |
+
|
| 35 |
+
NOISES = {
|
| 36 |
+
"None": None,
|
| 37 |
+
"Kitchen": "samples/dkitchen.wav",
|
| 38 |
+
"Living Room": "samples/dliving.wav",
|
| 39 |
+
"River": "samples/nriver.wav",
|
| 40 |
+
"Cafe": "samples/scafe.wav",
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
# --- HELPER FUNCTIONS ---
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def load_audio_torch(path, target_sr=TARGET_SR):
|
| 47 |
+
"""Replacement for df.load_audio using torchaudio"""
|
| 48 |
+
if path is None:
|
| 49 |
+
return None, None
|
| 50 |
+
|
| 51 |
+
sig, sr = torchaudio.load(path)
|
| 52 |
+
if sr != target_sr:
|
| 53 |
+
resampler = T.Resample(sr, target_sr)
|
| 54 |
+
sig = resampler(sig)
|
| 55 |
+
return sig, target_sr
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def save_audio_torch(path, tensor, sr):
|
| 59 |
+
"""Replacement for df.save_audio using torchaudio"""
|
| 60 |
+
# Ensure tensor is on CPU
|
| 61 |
+
tensor = tensor.detach().cpu()
|
| 62 |
+
# Check shape [channels, time], torchaudio expects this
|
| 63 |
+
if tensor.dim() == 1:
|
| 64 |
+
tensor = tensor.unsqueeze(0)
|
| 65 |
+
torchaudio.save(path, tensor, sr)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def mix_at_snr(clean, noise, snr, eps=1e-10):
|
| 69 |
+
"""Mix clean and noise signal at a given SNR."""
|
| 70 |
+
# Standardize to (1, T)
|
| 71 |
+
if clean.dim() == 1:
|
| 72 |
+
clean = clean.unsqueeze(0)
|
| 73 |
+
if noise.dim() == 1:
|
| 74 |
+
noise = noise.unsqueeze(0)
|
| 75 |
+
|
| 76 |
+
clean = clean.mean(0, keepdim=True)
|
| 77 |
+
noise = noise.mean(0, keepdim=True)
|
| 78 |
+
|
| 79 |
+
if noise.shape[1] < clean.shape[1]:
|
| 80 |
+
noise = noise.repeat((1, int(math.ceil(clean.shape[1] / noise.shape[1]))))
|
| 81 |
+
max_start = int(noise.shape[1] - clean.shape[1])
|
| 82 |
+
start = torch.randint(0, max_start, ()).item() if max_start > 0 else 0
|
| 83 |
+
noise = noise[:, start : start + clean.shape[1]]
|
| 84 |
+
|
| 85 |
+
E_speech = torch.mean(clean.pow(2)) + eps
|
| 86 |
+
E_noise = torch.mean(noise.pow(2))
|
| 87 |
+
K = torch.sqrt((E_noise / E_speech) * 10 ** (snr / 10) + eps)
|
| 88 |
+
noise = noise / K
|
| 89 |
+
mixture = clean + noise
|
| 90 |
+
|
| 91 |
+
max_m = mixture.abs().max()
|
| 92 |
+
if max_m > 1:
|
| 93 |
+
clean, noise, mixture = clean / max_m, noise / max_m, mixture / max_m
|
| 94 |
+
return clean, noise, mixture
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def specshow(
|
| 98 |
+
spec,
|
| 99 |
+
ax=None,
|
| 100 |
+
title=None,
|
| 101 |
+
sr=48000,
|
| 102 |
+
n_fft=None,
|
| 103 |
+
hop=None,
|
| 104 |
+
t=None,
|
| 105 |
+
f=None,
|
| 106 |
+
vmin=-100,
|
| 107 |
+
vmax=0,
|
| 108 |
+
cmap="inferno",
|
| 109 |
+
):
|
| 110 |
+
"""Plots a spectrogram of shape [F, T]"""
|
| 111 |
+
spec_np = spec.cpu().numpy() if isinstance(spec, torch.Tensor) else spec
|
| 112 |
+
if ax is None:
|
| 113 |
+
ax = plt
|
| 114 |
+
|
| 115 |
+
if n_fft is None:
|
| 116 |
+
n_fft = (spec.shape[0] - 1) * 2
|
| 117 |
+
hop = hop or n_fft // 4
|
| 118 |
+
|
| 119 |
+
if t is None:
|
| 120 |
+
t = np.arange(0, spec_np.shape[-1]) * hop / sr
|
| 121 |
+
if f is None:
|
| 122 |
+
f = np.arange(0, spec_np.shape[0]) * sr // 2 / (n_fft // 2) / 1000
|
| 123 |
+
|
| 124 |
+
im = ax.pcolormesh(
|
| 125 |
+
t, f, spec_np, rasterized=True, shading="auto", vmin=vmin, vmax=vmax, cmap=cmap
|
| 126 |
+
)
|
| 127 |
+
if title:
|
| 128 |
+
ax.set_title(title)
|
| 129 |
+
ax.set_xlabel("Time [s]")
|
| 130 |
+
ax.set_ylabel("Frequency [kHz]")
|
| 131 |
+
return im
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def spec_im(audio: torch.Tensor, sr=TARGET_SR, figsize=(15, 5), figure=None, ax=None) -> Image:
|
| 135 |
+
audio = torch.as_tensor(audio)
|
| 136 |
+
if audio.dim() > 1:
|
| 137 |
+
audio = audio.mean(dim=0) # Mix to mono for spec
|
| 138 |
+
|
| 139 |
+
n_fft = 1024
|
| 140 |
+
hop = 512
|
| 141 |
+
w = torch.hann_window(n_fft, device=audio.device)
|
| 142 |
+
spec = torch.stft(audio, n_fft, hop, window=w, return_complex=False)
|
| 143 |
+
spec = spec.div_(w.pow(2).sum())
|
| 144 |
+
spec = torch.view_as_complex(spec).abs().clamp_min(1e-12).log10().mul(10)
|
| 145 |
+
|
| 146 |
+
if figure is None:
|
| 147 |
+
figure = plt.figure(figsize=figsize)
|
| 148 |
+
figure.set_tight_layout(True)
|
| 149 |
+
|
| 150 |
+
if spec.dim() > 2:
|
| 151 |
+
spec = spec.squeeze(0)
|
| 152 |
+
specshow(spec, ax=ax, sr=sr, n_fft=n_fft, hop=hop)
|
| 153 |
+
|
| 154 |
+
figure.canvas.draw()
|
| 155 |
+
return Image.frombytes("RGB", figure.canvas.get_width_height(), figure.canvas.tostring_rgb())
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def cleanup_tmp(filter_list: List[str] = [], hours_keep=2):
|
| 159 |
+
# Basic cleanup logic
|
| 160 |
+
if os.path.exists("/tmp"):
|
| 161 |
+
for f in glob.glob("/tmp/*wav"):
|
| 162 |
+
# Only delete if very old or explicitly temp
|
| 163 |
+
pass
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# --- MAIN DEMO FUNCTION ---
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: Optional[str] = None):
|
| 170 |
+
if mic_input:
|
| 171 |
+
speech_upl = mic_input
|
| 172 |
+
|
| 173 |
+
sr = TARGET_SR
|
| 174 |
+
logger.info(f"Params: speech={speech_upl}, noise={noise_type}, snr={snr}")
|
| 175 |
+
snr = int(snr)
|
| 176 |
+
noise_fn = NOISES[noise_type]
|
| 177 |
+
|
| 178 |
+
# 1. Load Clean Speech
|
| 179 |
+
max_s = 10
|
| 180 |
+
if speech_upl is not None:
|
| 181 |
+
sample, _ = load_audio_torch(speech_upl, sr)
|
| 182 |
+
max_len = max_s * sr
|
| 183 |
+
if sample.shape[-1] > max_len:
|
| 184 |
+
start = torch.randint(0, sample.shape[-1] - max_len, ()).item()
|
| 185 |
+
sample = sample[..., start : start + max_len]
|
| 186 |
+
else:
|
| 187 |
+
# Fallback sample
|
| 188 |
+
sample, _ = load_audio_torch("samples/p232_013_clean.wav", sr)
|
| 189 |
+
sample = sample[..., : max_s * sr]
|
| 190 |
+
|
| 191 |
+
# Ensure channels first
|
| 192 |
+
if sample.dim() > 1 and sample.shape[0] > 1:
|
| 193 |
+
sample = sample.mean(dim=0, keepdim=True)
|
| 194 |
+
|
| 195 |
+
# 2. Add Noise (if selected)
|
| 196 |
+
if noise_fn is not None:
|
| 197 |
+
noise, _ = load_audio_torch(noise_fn, sr)
|
| 198 |
+
_, _, sample = mix_at_snr(sample, noise, snr)
|
| 199 |
+
|
| 200 |
+
# 3. Save Noisy File (Input for enhance_audio)
|
| 201 |
+
noisy_wav_path = tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False).name
|
| 202 |
+
save_audio_torch(noisy_wav_path, sample, sr)
|
| 203 |
+
|
| 204 |
+
# 4. Run Inference using your Custom Function
|
| 205 |
+
enhanced_wav_path = tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False).name
|
| 206 |
+
|
| 207 |
+
logger.info("Starting enhancement...")
|
| 208 |
+
# CALLING YOUR MODEL HERE
|
| 209 |
+
enhance_audio(CHECKPOINT_PATH, noisy_wav_path, enhanced_wav_path)
|
| 210 |
+
logger.info("Enhancement finished")
|
| 211 |
+
|
| 212 |
+
# 5. Load Enhanced Audio for Visualization
|
| 213 |
+
enhanced, _ = load_audio_torch(enhanced_wav_path, sr)
|
| 214 |
+
|
| 215 |
+
# 6. Generate Visuals
|
| 216 |
+
ax_noisy.clear()
|
| 217 |
+
ax_enh.clear()
|
| 218 |
+
noisy_im = spec_im(sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
|
| 219 |
+
enh_im = spec_im(enhanced, sr=sr, figure=fig_enh, ax=ax_enh)
|
| 220 |
+
|
| 221 |
+
return noisy_wav_path, noisy_im, enhanced_wav_path, enh_im
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def toggle(choice):
|
| 225 |
+
if choice == "mic":
|
| 226 |
+
return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
|
| 227 |
+
else:
|
| 228 |
+
return gr.update(visible=False, value=None), gr.update(visible=True, value=None)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# --- GRADIO INTERFACE ---
|
| 232 |
+
|
| 233 |
+
with gr.Blocks() as demo:
|
| 234 |
+
with gr.Row():
|
| 235 |
+
gr.Markdown(
|
| 236 |
+
"""
|
| 237 |
+
## Audio Enhancement Demo (Custom Model)
|
| 238 |
+
Upload audio or record from mic to test the model in `mva-proj`.
|
| 239 |
+
"""
|
| 240 |
+
)
|
| 241 |
+
with gr.Row():
|
| 242 |
+
with gr.Column():
|
| 243 |
+
radio = gr.Radio(["mic", "file"], value="file", label="Audio Source")
|
| 244 |
+
mic_input = gr.Mic(label="Microphone Input", type="filepath", visible=False)
|
| 245 |
+
audio_file = gr.Audio(type="filepath", label="File Input", visible=True)
|
| 246 |
+
inputs = [
|
| 247 |
+
audio_file,
|
| 248 |
+
gr.Dropdown(
|
| 249 |
+
label="Add background noise",
|
| 250 |
+
choices=list(NOISES.keys()),
|
| 251 |
+
value="None",
|
| 252 |
+
),
|
| 253 |
+
gr.Dropdown(
|
| 254 |
+
label="Noise Level (SNR)",
|
| 255 |
+
choices=["-5", "0", "10", "20"],
|
| 256 |
+
value="10",
|
| 257 |
+
),
|
| 258 |
+
mic_input,
|
| 259 |
+
]
|
| 260 |
+
btn = gr.Button("Denoise", variant="primary")
|
| 261 |
+
with gr.Column():
|
| 262 |
+
outputs = [
|
| 263 |
+
gr.Audio(type="filepath", label="Noisy Input"),
|
| 264 |
+
gr.Image(label="Noisy Spectrogram"),
|
| 265 |
+
gr.Audio(type="filepath", label="Enhanced Output"),
|
| 266 |
+
gr.Image(label="Enhanced Spectrogram"),
|
| 267 |
+
]
|
| 268 |
+
|
| 269 |
+
btn.click(fn=demo_fn, inputs=inputs, outputs=outputs, api_name="denoise")
|
| 270 |
+
radio.change(toggle, radio, [mic_input, audio_file])
|
| 271 |
+
|
| 272 |
+
# Examples (Ensure these files exist in your folder)
|
| 273 |
+
if os.path.exists("samples/p232_013_clean.wav"):
|
| 274 |
+
gr.Examples(
|
| 275 |
+
[
|
| 276 |
+
["samples/p232_013_clean.wav", "Kitchen", "10"],
|
| 277 |
+
["samples/p232_013_clean.wav", "Cafe", "10"],
|
| 278 |
+
],
|
| 279 |
+
fn=demo_fn,
|
| 280 |
+
inputs=inputs,
|
| 281 |
+
outputs=outputs,
|
| 282 |
+
cache_examples=False, # Disable cache if model changes frequently
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
demo.launch(enable_queue=True)
|
fspen/.gitignore
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
./data/
|
| 13 |
+
develop-eggs/
|
| 14 |
+
dist/
|
| 15 |
+
downloads/
|
| 16 |
+
eggs/
|
| 17 |
+
.eggs/
|
| 18 |
+
lib/
|
| 19 |
+
lib64/
|
| 20 |
+
parts/
|
| 21 |
+
sdist/
|
| 22 |
+
var/
|
| 23 |
+
wheels/
|
| 24 |
+
pip-wheel-metadata/
|
| 25 |
+
share/python-wheels/
|
| 26 |
+
*.egg-info/
|
| 27 |
+
.installed.cfg
|
| 28 |
+
*.egg
|
| 29 |
+
MANIFEST
|
| 30 |
+
|
| 31 |
+
# PyInstaller
|
| 32 |
+
# Usually these files are written by a python script from a template
|
| 33 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 34 |
+
*.manifest
|
| 35 |
+
*.spec
|
| 36 |
+
|
| 37 |
+
# Installer logs
|
| 38 |
+
pip-log.txt
|
| 39 |
+
pip-delete-this-directory.txt
|
| 40 |
+
|
| 41 |
+
# Unit test / coverage reports
|
| 42 |
+
htmlcov/
|
| 43 |
+
.tox/
|
| 44 |
+
.nox/
|
| 45 |
+
.coverage
|
| 46 |
+
.coverage.*
|
| 47 |
+
.cache
|
| 48 |
+
nosetests.xml
|
| 49 |
+
coverage.xml
|
| 50 |
+
*.cover
|
| 51 |
+
*.py,cover
|
| 52 |
+
.hypothesis/
|
| 53 |
+
.pytest_cache/
|
| 54 |
+
|
| 55 |
+
# Translations
|
| 56 |
+
*.mo
|
| 57 |
+
*.pot
|
| 58 |
+
|
| 59 |
+
# Django stuff:
|
| 60 |
+
*.log
|
| 61 |
+
local_settings.py
|
| 62 |
+
db.sqlite3
|
| 63 |
+
db.sqlite3-journal
|
| 64 |
+
|
| 65 |
+
# Flask stuff:
|
| 66 |
+
instance/
|
| 67 |
+
.webassets-cache
|
| 68 |
+
|
| 69 |
+
# Scrapy stuff:
|
| 70 |
+
.scrapy
|
| 71 |
+
|
| 72 |
+
# Sphinx documentation
|
| 73 |
+
docs/_build/
|
| 74 |
+
|
| 75 |
+
# PyBuilder
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
.python-version
|
| 87 |
+
|
| 88 |
+
# pipenv
|
| 89 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 90 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 91 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 92 |
+
# install all needed dependencies.
|
| 93 |
+
#Pipfile.lock
|
| 94 |
+
|
| 95 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 96 |
+
__pypackages__/
|
| 97 |
+
|
| 98 |
+
# Celery stuff
|
| 99 |
+
celerybeat-schedule
|
| 100 |
+
celerybeat.pid
|
| 101 |
+
|
| 102 |
+
# SageMath parsed files
|
| 103 |
+
*.sage.py
|
| 104 |
+
|
| 105 |
+
# Environments
|
| 106 |
+
.venv
|
| 107 |
+
env/
|
| 108 |
+
venv/
|
| 109 |
+
ENV/
|
| 110 |
+
env.bak/
|
| 111 |
+
venv.bak/
|
| 112 |
+
|
| 113 |
+
# Spyder project settings
|
| 114 |
+
.spyderproject
|
| 115 |
+
.spyproject
|
| 116 |
+
|
| 117 |
+
# Rope project settings
|
| 118 |
+
.ropeproject
|
| 119 |
+
|
| 120 |
+
# mkdocs documentation
|
| 121 |
+
/site
|
| 122 |
+
|
| 123 |
+
# mypy
|
| 124 |
+
.mypy_cache/
|
| 125 |
+
.dmypy.json
|
| 126 |
+
dmypy.json
|
| 127 |
+
|
| 128 |
+
# Pyre type checker
|
| 129 |
+
.pyre/
|
| 130 |
+
|
| 131 |
+
### VisualStudioCode
|
| 132 |
+
.vscode/*
|
| 133 |
+
!.vscode/settings.json
|
| 134 |
+
!.vscode/tasks.json
|
| 135 |
+
!.vscode/launch.json
|
| 136 |
+
!.vscode/extensions.json
|
| 137 |
+
*.code-workspace
|
| 138 |
+
**/.vscode
|
| 139 |
+
|
| 140 |
+
# JetBrains
|
| 141 |
+
.idea/
|
| 142 |
+
|
| 143 |
+
# Data & Models
|
| 144 |
+
*.h5
|
| 145 |
+
*.tar
|
| 146 |
+
*.tar.gz
|
| 147 |
+
|
| 148 |
+
# Lightning-Hydra-Template
|
| 149 |
+
configs/local/default.yaml
|
| 150 |
+
/data/
|
| 151 |
+
/logs/
|
| 152 |
+
.env
|
| 153 |
+
|
| 154 |
+
# Aim logging
|
| 155 |
+
.aim
|
| 156 |
+
Fspen_an_Ultra-Lightweight_Network_for_Real_Time_Speech_Enahncment.pdf
|
| 157 |
+
voicebank_data
|
| 158 |
+
voicebank_wavs
|
fspen/.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
default_language_version:
|
| 2 |
+
python: python3
|
| 3 |
+
|
| 4 |
+
repos:
|
| 5 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 6 |
+
rev: v6.0.0
|
| 7 |
+
hooks:
|
| 8 |
+
# list of supported hooks: https://pre-commit.com/hooks.html
|
| 9 |
+
- id: trailing-whitespace
|
| 10 |
+
- id: end-of-file-fixer
|
| 11 |
+
- id: check-docstring-first
|
| 12 |
+
- id: check-yaml
|
| 13 |
+
- id: debug-statements
|
| 14 |
+
- id: detect-private-key
|
| 15 |
+
- id: check-executables-have-shebangs
|
| 16 |
+
- id: check-toml
|
| 17 |
+
- id: check-case-conflict
|
| 18 |
+
- id: check-added-large-files
|
| 19 |
+
|
| 20 |
+
# python code formatting
|
| 21 |
+
- repo: https://github.com/psf/black
|
| 22 |
+
rev: 25.12.0
|
| 23 |
+
hooks:
|
| 24 |
+
- id: black
|
| 25 |
+
args: [--line-length, "99"]
|
| 26 |
+
|
| 27 |
+
# python import sorting
|
| 28 |
+
- repo: https://github.com/PyCQA/isort
|
| 29 |
+
rev: 7.0.0
|
| 30 |
+
hooks:
|
| 31 |
+
- id: isort
|
| 32 |
+
args: ["--profile", "black", "--filter-files"]
|
| 33 |
+
|
| 34 |
+
# python upgrading syntax to newer version
|
| 35 |
+
- repo: https://github.com/asottile/pyupgrade
|
| 36 |
+
rev: v3.21.2
|
| 37 |
+
hooks:
|
| 38 |
+
- id: pyupgrade
|
| 39 |
+
args: [--py38-plus]
|
| 40 |
+
|
| 41 |
+
# python docstring formatting
|
| 42 |
+
- repo: https://github.com/myint/docformatter
|
| 43 |
+
rev: v1.7.7
|
| 44 |
+
hooks:
|
| 45 |
+
- id: docformatter
|
| 46 |
+
args:
|
| 47 |
+
[
|
| 48 |
+
--in-place,
|
| 49 |
+
--wrap-summaries=99,
|
| 50 |
+
--wrap-descriptions=99,
|
| 51 |
+
--style=sphinx,
|
| 52 |
+
--black,
|
| 53 |
+
]
|
| 54 |
+
|
| 55 |
+
# # python docstring coverage checking
|
| 56 |
+
# - repo: https://github.com/econchick/interrogate
|
| 57 |
+
# rev: 1.7.0 # or master if you're bold
|
| 58 |
+
# hooks:
|
| 59 |
+
# - id: interrogate
|
| 60 |
+
# args:
|
| 61 |
+
# [
|
| 62 |
+
# --verbose,
|
| 63 |
+
# --fail-under=80,
|
| 64 |
+
# --ignore-init-module,
|
| 65 |
+
# --ignore-init-method,
|
| 66 |
+
# --ignore-module,
|
| 67 |
+
# --ignore-nested-functions,
|
| 68 |
+
# -vv,
|
| 69 |
+
# ]
|
| 70 |
+
|
| 71 |
+
# python check (PEP8), programming errors and code complexity
|
| 72 |
+
- repo: https://github.com/PyCQA/flake8
|
| 73 |
+
rev: 7.3.0
|
| 74 |
+
hooks:
|
| 75 |
+
- id: flake8
|
| 76 |
+
args:
|
| 77 |
+
[
|
| 78 |
+
"--extend-ignore",
|
| 79 |
+
"E203,E402,E501,F401,F841,RST2,RST301",
|
| 80 |
+
"--exclude",
|
| 81 |
+
"logs/*,data/*",
|
| 82 |
+
]
|
| 83 |
+
additional_dependencies: [flake8-rst-docstrings==0.3.0]
|
| 84 |
+
|
| 85 |
+
# python security linter
|
| 86 |
+
- repo: https://github.com/PyCQA/bandit
|
| 87 |
+
rev: "1.9.2"
|
| 88 |
+
hooks:
|
| 89 |
+
- id: bandit
|
| 90 |
+
args: ["-s", "B101,B311"]
|
| 91 |
+
|
| 92 |
+
# yaml formatting
|
| 93 |
+
- repo: https://github.com/pre-commit/mirrors-prettier
|
| 94 |
+
rev: v4.0.0-alpha.8
|
| 95 |
+
hooks:
|
| 96 |
+
- id: prettier
|
| 97 |
+
types: [yaml]
|
| 98 |
+
exclude: "environment.yaml"
|
| 99 |
+
|
| 100 |
+
# shell scripts linter
|
| 101 |
+
- repo: https://github.com/shellcheck-py/shellcheck-py
|
| 102 |
+
rev: v0.11.0.1
|
| 103 |
+
hooks:
|
| 104 |
+
- id: shellcheck
|
| 105 |
+
|
| 106 |
+
# word spelling linter
|
| 107 |
+
- repo: https://github.com/codespell-project/codespell
|
| 108 |
+
rev: v2.4.1
|
| 109 |
+
hooks:
|
| 110 |
+
- id: codespell
|
| 111 |
+
args:
|
| 112 |
+
- --skip=logs/**,data/**,*.ipynb
|
| 113 |
+
# - --ignore-words-list=abc,def
|
| 114 |
+
|
| 115 |
+
# jupyter notebook cell output clearing
|
| 116 |
+
- repo: https://github.com/kynan/nbstripout
|
| 117 |
+
rev: 0.8.2
|
| 118 |
+
hooks:
|
| 119 |
+
- id: nbstripout
|
| 120 |
+
|
| 121 |
+
# jupyter notebook linting
|
| 122 |
+
- repo: https://github.com/nbQA-dev/nbQA
|
| 123 |
+
rev: 1.9.1
|
| 124 |
+
hooks:
|
| 125 |
+
- id: nbqa-black
|
| 126 |
+
args: ["--line-length=99"]
|
| 127 |
+
- id: nbqa-isort
|
| 128 |
+
args: ["--profile=black"]
|
| 129 |
+
- id: nbqa-flake8
|
| 130 |
+
args:
|
| 131 |
+
[
|
| 132 |
+
"--extend-ignore=E203,E402,E501,F401,F841",
|
| 133 |
+
"--exclude=logs/*,data/*",
|
| 134 |
+
]
|
fspen/.project-root
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# this file is required for inferring the project root directory
|
| 2 |
+
# do not delete
|
fspen/Makefile
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
help: ## Show help
|
| 3 |
+
@grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
| 4 |
+
|
| 5 |
+
clean: ## Clean autogenerated files
|
| 6 |
+
rm -rf dist
|
| 7 |
+
find . -type f -name "*.DS_Store" -ls -delete
|
| 8 |
+
find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
|
| 9 |
+
find . | grep -E ".pytest_cache" | xargs rm -rf
|
| 10 |
+
find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
|
| 11 |
+
rm -f .coverage
|
| 12 |
+
|
| 13 |
+
clean-logs: ## Clean logs
|
| 14 |
+
rm -rf logs/**
|
| 15 |
+
|
| 16 |
+
format: ## Run pre-commit hooks
|
| 17 |
+
pre-commit run -a
|
| 18 |
+
|
| 19 |
+
sync: ## Merge changes from main branch to your current branch
|
| 20 |
+
git pull
|
| 21 |
+
git pull origin main
|
| 22 |
+
|
| 23 |
+
test: ## Run not slow tests
|
| 24 |
+
pytest -k "not slow"
|
| 25 |
+
|
| 26 |
+
test-full: ## Run all tests
|
| 27 |
+
pytest
|
| 28 |
+
|
| 29 |
+
train: ## Train the model
|
| 30 |
+
python src/train.py
|
fspen/README.md
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
# Your Project Name
|
| 4 |
+
|
| 5 |
+
<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
|
| 6 |
+
<a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a>
|
| 7 |
+
<a href="https://hydra.cc/"><img alt="Config: Hydra" src="https://img.shields.io/badge/Config-Hydra-89b8cd"></a>
|
| 8 |
+
<a href="https://github.com/ashleve/lightning-hydra-template"><img alt="Template" src="https://img.shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=gray"></a><br>
|
| 9 |
+
[](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=10446016)
|
| 10 |
+
|
| 11 |
+
</div>
|
| 12 |
+
|
| 13 |
+
## Description
|
| 14 |
+
|
| 15 |
+
What it does
|
| 16 |
+
|
| 17 |
+
## Installation
|
| 18 |
+
|
| 19 |
+
#### Pip
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
# clone project
|
| 23 |
+
git clone https://github.com/iliasslasri/
|
| 24 |
+
cd
|
| 25 |
+
|
| 26 |
+
# [OPTIONAL] create conda environment
|
| 27 |
+
conda create -n myenv python=3.9
|
| 28 |
+
conda activate myenv
|
| 29 |
+
|
| 30 |
+
# install pytorch according to instructions
|
| 31 |
+
# https://pytorch.org/get-started/
|
| 32 |
+
|
| 33 |
+
# install requirements
|
| 34 |
+
pip install -r requirements.txt
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
#### Conda
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
# clone project
|
| 41 |
+
git clone https://github.com/YourGithubName/your-repo-name
|
| 42 |
+
cd your-repo-name
|
| 43 |
+
|
| 44 |
+
# create conda environment and install dependencies
|
| 45 |
+
conda env create -f environment.yaml -n myenv
|
| 46 |
+
|
| 47 |
+
# activate conda environment
|
| 48 |
+
conda activate myenv
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
## How to run
|
| 52 |
+
|
| 53 |
+
Train model with default configuration
|
| 54 |
+
|
| 55 |
+
```bash
|
| 56 |
+
# train on CPU
|
| 57 |
+
python src/train.py trainer=cpu
|
| 58 |
+
|
| 59 |
+
# train on GPU
|
| 60 |
+
python src/train.py trainer=gpu
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
Train model with chosen experiment configuration from [configs/experiment/](configs/experiment/)
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
python src/train.py experiment=experiment_name.yaml
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
You can override any parameter from command line like this
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
python src/train.py trainer.max_epochs=20 data.batch_size=64
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
## Set up environment for dev
|
| 79 |
+
```bash
|
| 80 |
+
pre-commit install
|
| 81 |
+
|
| 82 |
+
pip install \
|
| 83 |
+
"torch==2.0.1+cu118" \
|
| 84 |
+
"torchvision==0.15.2+cu118" \
|
| 85 |
+
"torchaudio==2.0.2+cu118" \
|
| 86 |
+
"lightning==2.0.9" \
|
| 87 |
+
"torchmetrics==0.11.4" \
|
| 88 |
+
"numpy<2.0" \
|
| 89 |
+
"pesq" \
|
| 90 |
+
"hydra-colorlog" \
|
| 91 |
+
--extra-index-url https://download.pytorch.org/whl/cu118
|
| 92 |
+
|
| 93 |
+
python3 src/train.py callbacks.rich_progress_bar=null
|
| 94 |
+
```
|
fspen/configs/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# this file is needed here to include configs when building project as a package
|
fspen/configs/callbacks/default.yaml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- model_checkpoint
|
| 3 |
+
- early_stopping
|
| 4 |
+
- model_summary
|
| 5 |
+
- rich_progress_bar
|
| 6 |
+
- _self_
|
| 7 |
+
|
| 8 |
+
model_checkpoint:
|
| 9 |
+
dirpath: ${paths.output_dir}/checkpoints
|
| 10 |
+
filename: "epoch_{epoch:03d}"
|
| 11 |
+
monitor: "val/loss"
|
| 12 |
+
mode: "min"
|
| 13 |
+
save_last: True
|
| 14 |
+
auto_insert_metric_name: False
|
| 15 |
+
save_top_k: 3
|
| 16 |
+
|
| 17 |
+
early_stopping:
|
| 18 |
+
monitor: "val/loss"
|
| 19 |
+
patience: 100
|
| 20 |
+
mode: "max"
|
| 21 |
+
|
| 22 |
+
model_summary:
|
| 23 |
+
max_depth: -1
|
fspen/configs/callbacks/early_stopping.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
|
| 2 |
+
|
| 3 |
+
early_stopping:
|
| 4 |
+
_target_: lightning.pytorch.callbacks.EarlyStopping
|
| 5 |
+
monitor: ??? # quantity to be monitored, must be specified !!!
|
| 6 |
+
min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
|
| 7 |
+
patience: 3 # number of checks with no improvement after which training will be stopped
|
| 8 |
+
verbose: False # verbosity mode
|
| 9 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 10 |
+
strict: True # whether to crash the training if monitor is not found in the validation metrics
|
| 11 |
+
check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
|
| 12 |
+
stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
|
| 13 |
+
divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
|
| 14 |
+
check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
|
| 15 |
+
# log_rank_zero_only: False # this keyword argument isn't available in stable version
|
fspen/configs/callbacks/model_checkpoint.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
|
| 2 |
+
|
| 3 |
+
model_checkpoint:
|
| 4 |
+
_target_: lightning.pytorch.callbacks.ModelCheckpoint
|
| 5 |
+
dirpath: null # directory to save the model file
|
| 6 |
+
filename: null # checkpoint filename
|
| 7 |
+
monitor: null # name of the logged metric which determines when model is improving
|
| 8 |
+
verbose: False # verbosity mode
|
| 9 |
+
save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
|
| 10 |
+
save_top_k: 1 # save k best models (determined by above metric)
|
| 11 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 12 |
+
auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
|
| 13 |
+
save_weights_only: False # if True, then only the model’s weights will be saved
|
| 14 |
+
every_n_train_steps: null # number of training steps between checkpoints
|
| 15 |
+
train_time_interval: null # checkpoints are monitored at the specified time interval
|
| 16 |
+
every_n_epochs: null # number of epochs between checkpoints
|
| 17 |
+
save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
|
fspen/configs/callbacks/model_summary.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
|
| 2 |
+
|
| 3 |
+
model_summary:
|
| 4 |
+
_target_: lightning.pytorch.callbacks.RichModelSummary
|
| 5 |
+
max_depth: 1 # the maximum depth of layer nesting that the summary will include
|
fspen/configs/callbacks/none.yaml
ADDED
|
File without changes
|
fspen/configs/callbacks/rich_progress_bar.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
|
| 2 |
+
|
| 3 |
+
rich_progress_bar:
|
| 4 |
+
_target_: lightning.pytorch.callbacks.RichProgressBar
|
fspen/configs/data/speech_enhancement.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: src.data.datamodule.SpeechEnhancementDataModule
|
| 2 |
+
dataset:
|
| 3 |
+
_target_: src.data.dataset.SpeechEnhancementDataset
|
| 4 |
+
sample_rate: 16000
|
| 5 |
+
segment_len: 10.0 # in seconds
|
| 6 |
+
n_fft: 512
|
| 7 |
+
hop_length: 128
|
| 8 |
+
win_length: 512
|
| 9 |
+
noisy_dir: ${paths.noisy_dir}
|
| 10 |
+
clean_dir: ${paths.clean_dir}
|
| 11 |
+
batch_size: 64 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
|
| 12 |
+
val_split: 0.1
|
| 13 |
+
num_workers: 2
|
fspen/configs/debug/default.yaml
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# default debugging setup, runs 1 full epoch
|
| 4 |
+
# other debugging configs can inherit from this one
|
| 5 |
+
|
| 6 |
+
# overwrite task name so debugging logs are stored in separate folder
|
| 7 |
+
task_name: "debug"
|
| 8 |
+
|
| 9 |
+
# disable callbacks and loggers during debugging
|
| 10 |
+
callbacks: null
|
| 11 |
+
logger: null
|
| 12 |
+
|
| 13 |
+
extras:
|
| 14 |
+
ignore_warnings: False
|
| 15 |
+
enforce_tags: False
|
| 16 |
+
|
| 17 |
+
# sets level of all command line loggers to 'DEBUG'
|
| 18 |
+
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
|
| 19 |
+
hydra:
|
| 20 |
+
job_logging:
|
| 21 |
+
root:
|
| 22 |
+
level: DEBUG
|
| 23 |
+
|
| 24 |
+
# use this to also set hydra loggers to 'DEBUG'
|
| 25 |
+
# verbose: True
|
| 26 |
+
|
| 27 |
+
trainer:
|
| 28 |
+
max_epochs: 1
|
| 29 |
+
accelerator: cpu # debuggers don't like gpus
|
| 30 |
+
devices: 1 # debuggers don't like multiprocessing
|
| 31 |
+
detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
|
| 32 |
+
|
| 33 |
+
data:
|
| 34 |
+
num_workers: 0 # debuggers don't like multiprocessing
|
| 35 |
+
pin_memory: False # disable gpu memory pin
|
fspen/configs/debug/fdr.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# runs 1 train, 1 validation and 1 test step
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- default
|
| 7 |
+
|
| 8 |
+
trainer:
|
| 9 |
+
fast_dev_run: true
|
fspen/configs/debug/limit.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# uses only 1% of the training data and 5% of validation/test data
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- default
|
| 7 |
+
|
| 8 |
+
trainer:
|
| 9 |
+
max_epochs: 3
|
| 10 |
+
limit_train_batches: 0.01
|
| 11 |
+
limit_val_batches: 0.05
|
| 12 |
+
limit_test_batches: 0.05
|
fspen/configs/debug/overfit.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# overfits to 3 batches
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- default
|
| 7 |
+
|
| 8 |
+
trainer:
|
| 9 |
+
max_epochs: 20
|
| 10 |
+
overfit_batches: 3
|
| 11 |
+
|
| 12 |
+
# model ckpt and early stopping need to be disabled during overfitting
|
| 13 |
+
callbacks: null
|
fspen/configs/debug/profiler.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# runs with execution time profiling
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- default
|
| 7 |
+
|
| 8 |
+
trainer:
|
| 9 |
+
max_epochs: 1
|
| 10 |
+
profiler: "simple"
|
| 11 |
+
# profiler: "advanced"
|
| 12 |
+
# profiler: "pytorch"
|
fspen/configs/eval.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- _self_
|
| 5 |
+
- data: speech_enhancement
|
| 6 |
+
- model: fspen
|
| 7 |
+
- callbacks: default
|
| 8 |
+
- logger: tensorboard
|
| 9 |
+
- trainer: gpu
|
| 10 |
+
- paths: eval
|
| 11 |
+
- extras: default
|
| 12 |
+
- hydra: default
|
| 13 |
+
|
| 14 |
+
task_name: "eval"
|
| 15 |
+
|
| 16 |
+
tags: ["dev"]
|
| 17 |
+
|
| 18 |
+
# passing checkpoint path is necessary for evaluation
|
| 19 |
+
ckpt_path: ???
|
fspen/configs/experiment/example.yaml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python train.py experiment=example
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /data: mnist
|
| 8 |
+
- override /model: mnist
|
| 9 |
+
- override /callbacks: default
|
| 10 |
+
- override /trainer: default
|
| 11 |
+
|
| 12 |
+
# all parameters below will be merged with parameters from default configurations set above
|
| 13 |
+
# this allows you to overwrite only specified parameters
|
| 14 |
+
|
| 15 |
+
tags: ["mnist", "simple_dense_net"]
|
| 16 |
+
|
| 17 |
+
seed: 12345
|
| 18 |
+
|
| 19 |
+
trainer:
|
| 20 |
+
min_epochs: 10
|
| 21 |
+
max_epochs: 10
|
| 22 |
+
gradient_clip_val: 0.5
|
| 23 |
+
|
| 24 |
+
model:
|
| 25 |
+
optimizer:
|
| 26 |
+
lr: 0.002
|
| 27 |
+
net:
|
| 28 |
+
lin1_size: 128
|
| 29 |
+
lin2_size: 256
|
| 30 |
+
lin3_size: 64
|
| 31 |
+
compile: false
|
| 32 |
+
|
| 33 |
+
data:
|
| 34 |
+
batch_size: 64
|
| 35 |
+
|
| 36 |
+
logger:
|
| 37 |
+
wandb:
|
| 38 |
+
tags: ${tags}
|
| 39 |
+
group: "mnist"
|
| 40 |
+
aim:
|
| 41 |
+
experiment: "mnist"
|
fspen/configs/extras/default.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# disable python warnings if they annoy you
|
| 2 |
+
ignore_warnings: False
|
| 3 |
+
|
| 4 |
+
# ask user for tags if none are provided in the config
|
| 5 |
+
enforce_tags: True
|
| 6 |
+
|
| 7 |
+
# pretty print config tree at the start of the run using Rich library
|
| 8 |
+
print_config: True
|
fspen/configs/hparams_search/mnist_optuna.yaml
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# example hyperparameter optimization of some experiment with Optuna:
|
| 4 |
+
# python train.py -m hparams_search=mnist_optuna experiment=example
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /hydra/sweeper: optuna
|
| 8 |
+
|
| 9 |
+
# choose metric which will be optimized by Optuna
|
| 10 |
+
# make sure this is the correct name of some metric logged in lightning module!
|
| 11 |
+
optimized_metric: "val/acc_best"
|
| 12 |
+
|
| 13 |
+
# here we define Optuna hyperparameter search
|
| 14 |
+
# it optimizes for value returned from function with @hydra.main decorator
|
| 15 |
+
# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
|
| 16 |
+
hydra:
|
| 17 |
+
mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
|
| 18 |
+
|
| 19 |
+
sweeper:
|
| 20 |
+
_target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
|
| 21 |
+
|
| 22 |
+
# storage URL to persist optimization results
|
| 23 |
+
# for example, you can use SQLite if you set 'sqlite:///example.db'
|
| 24 |
+
storage: null
|
| 25 |
+
|
| 26 |
+
# name of the study to persist optimization results
|
| 27 |
+
study_name: null
|
| 28 |
+
|
| 29 |
+
# number of parallel workers
|
| 30 |
+
n_jobs: 1
|
| 31 |
+
|
| 32 |
+
# 'minimize' or 'maximize' the objective
|
| 33 |
+
direction: maximize
|
| 34 |
+
|
| 35 |
+
# total number of runs that will be executed
|
| 36 |
+
n_trials: 20
|
| 37 |
+
|
| 38 |
+
# choose Optuna hyperparameter sampler
|
| 39 |
+
# you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
|
| 40 |
+
# docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
|
| 41 |
+
sampler:
|
| 42 |
+
_target_: optuna.samplers.TPESampler
|
| 43 |
+
seed: 1234
|
| 44 |
+
n_startup_trials: 10 # number of random sampling runs before optimization starts
|
| 45 |
+
|
| 46 |
+
# define hyperparameter search space
|
| 47 |
+
params:
|
| 48 |
+
model.optimizer.lr: interval(0.0001, 0.1)
|
| 49 |
+
data.batch_size: choice(32, 64, 128, 256)
|
| 50 |
+
model.net.lin1_size: choice(64, 128, 256)
|
| 51 |
+
model.net.lin2_size: choice(64, 128, 256)
|
| 52 |
+
model.net.lin3_size: choice(32, 64, 128, 256)
|
fspen/configs/hydra/default.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://hydra.cc/docs/configure_hydra/intro/
|
| 2 |
+
|
| 3 |
+
# enable color logging
|
| 4 |
+
defaults:
|
| 5 |
+
- override hydra_logging: colorlog
|
| 6 |
+
- override job_logging: colorlog
|
| 7 |
+
|
| 8 |
+
# output directory, generated dynamically on each run
|
| 9 |
+
run:
|
| 10 |
+
dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
|
| 11 |
+
sweep:
|
| 12 |
+
dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
|
| 13 |
+
subdir: ${hydra.job.num}
|
| 14 |
+
|
| 15 |
+
job_logging:
|
| 16 |
+
handlers:
|
| 17 |
+
file:
|
| 18 |
+
# Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
|
| 19 |
+
filename: ${hydra.runtime.output_dir}/${task_name}.log
|
fspen/configs/local/.gitkeep
ADDED
|
File without changes
|
fspen/configs/logger/aim.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://aimstack.io/
|
| 2 |
+
|
| 3 |
+
# example usage in lightning module:
|
| 4 |
+
# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
|
| 5 |
+
|
| 6 |
+
# open the Aim UI with the following command (run in the folder containing the `.aim` folder):
|
| 7 |
+
# `aim up`
|
| 8 |
+
|
| 9 |
+
aim:
|
| 10 |
+
_target_: aim.pytorch_lightning.AimLogger
|
| 11 |
+
repo: ${paths.root_dir} # .aim folder will be created here
|
| 12 |
+
# repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
|
| 13 |
+
|
| 14 |
+
# aim allows to group runs under experiment name
|
| 15 |
+
experiment: null # any string, set to "default" if not specified
|
| 16 |
+
|
| 17 |
+
train_metric_prefix: "train/"
|
| 18 |
+
val_metric_prefix: "val/"
|
| 19 |
+
test_metric_prefix: "test/"
|
| 20 |
+
|
| 21 |
+
# sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
|
| 22 |
+
system_tracking_interval: 10 # set to null to disable system metrics tracking
|
| 23 |
+
|
| 24 |
+
# enable/disable logging of system params such as installed packages, git info, env vars, etc.
|
| 25 |
+
log_system_params: true
|
| 26 |
+
|
| 27 |
+
# enable/disable tracking console logs (default value is true)
|
| 28 |
+
capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550
|
fspen/configs/logger/comet.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://www.comet.ml
|
| 2 |
+
|
| 3 |
+
comet:
|
| 4 |
+
_target_: lightning.pytorch.loggers.comet.CometLogger
|
| 5 |
+
api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
|
| 6 |
+
save_dir: "${paths.output_dir}"
|
| 7 |
+
project_name: "lightning-hydra-template"
|
| 8 |
+
rest_api_key: null
|
| 9 |
+
# experiment_name: ""
|
| 10 |
+
experiment_key: null # set to resume experiment
|
| 11 |
+
offline: False
|
| 12 |
+
prefix: ""
|
fspen/configs/logger/csv.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# csv logger built in lightning
|
| 2 |
+
|
| 3 |
+
csv:
|
| 4 |
+
_target_: lightning.pytorch.loggers.csv_logs.CSVLogger
|
| 5 |
+
save_dir: "${paths.output_dir}"
|
| 6 |
+
name: "csv/"
|
| 7 |
+
prefix: ""
|
fspen/configs/logger/many_loggers.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# train with many loggers at once
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
# - comet
|
| 5 |
+
- csv
|
| 6 |
+
# - mlflow
|
| 7 |
+
# - neptune
|
| 8 |
+
- tensorboard
|
| 9 |
+
- wandb
|
fspen/configs/logger/mlflow.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://mlflow.org
|
| 2 |
+
|
| 3 |
+
mlflow:
|
| 4 |
+
_target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
|
| 5 |
+
# experiment_name: ""
|
| 6 |
+
# run_name: ""
|
| 7 |
+
tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
|
| 8 |
+
tags: null
|
| 9 |
+
# save_dir: "./mlruns"
|
| 10 |
+
prefix: ""
|
| 11 |
+
artifact_location: null
|
| 12 |
+
# run_id: ""
|
fspen/configs/logger/neptune.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://neptune.ai
|
| 2 |
+
|
| 3 |
+
neptune:
|
| 4 |
+
_target_: lightning.pytorch.loggers.neptune.NeptuneLogger
|
| 5 |
+
api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
|
| 6 |
+
project: username/lightning-hydra-template
|
| 7 |
+
# name: ""
|
| 8 |
+
log_model_checkpoints: True
|
| 9 |
+
prefix: ""
|
fspen/configs/logger/tensorboard.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://www.tensorflow.org/tensorboard/
|
| 2 |
+
|
| 3 |
+
tensorboard:
|
| 4 |
+
_target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
|
| 5 |
+
save_dir: "${paths.output_dir}/tensorboard/"
|
| 6 |
+
name: null
|
| 7 |
+
log_graph: False
|
| 8 |
+
default_hp_metric: True
|
| 9 |
+
prefix: ""
|
| 10 |
+
# version: ""
|
fspen/configs/logger/wandb.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://wandb.ai
|
| 2 |
+
|
| 3 |
+
wandb:
|
| 4 |
+
_target_: lightning.pytorch.loggers.wandb.WandbLogger
|
| 5 |
+
# name: "" # name of the run (normally generated by wandb)
|
| 6 |
+
save_dir: "${paths.output_dir}"
|
| 7 |
+
offline: False
|
| 8 |
+
id: null # pass correct id to resume experiment!
|
| 9 |
+
anonymous: null # enable anonymous logging
|
| 10 |
+
project: "lightning-hydra-template"
|
| 11 |
+
log_model: False # upload lightning ckpts
|
| 12 |
+
prefix: "" # a string to put at the beginning of metric keys
|
| 13 |
+
# entity: "" # set to name of your wandb team
|
| 14 |
+
group: ""
|
| 15 |
+
tags: []
|
| 16 |
+
job_type: ""
|
fspen/configs/model/fspen.yaml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: src.models.fspen_module.FSPENLitModule
|
| 2 |
+
|
| 3 |
+
optimizer:
|
| 4 |
+
_target_: torch.optim.Adam
|
| 5 |
+
_partial_: true
|
| 6 |
+
lr: 0.001
|
| 7 |
+
weight_decay: 0.0
|
| 8 |
+
|
| 9 |
+
scheduler:
|
| 10 |
+
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
|
| 11 |
+
_partial_: true
|
| 12 |
+
mode: min
|
| 13 |
+
factor: 0.1
|
| 14 |
+
patience: 10
|
| 15 |
+
|
| 16 |
+
net:
|
| 17 |
+
_target_: src.models.fspen.FSPEN
|
| 18 |
+
|
| 19 |
+
# compile model for faster training with pytorch 2.0
|
| 20 |
+
compile: false
|
| 21 |
+
|
| 22 |
+
criterion:
|
| 23 |
+
_target_: src.models.components.loss.MultiResolutionSTFTLoss
|
| 24 |
+
fft_sizes: [512, 1024, 2048]
|
fspen/configs/paths/default.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# path to root directory
|
| 2 |
+
# this requires PROJECT_ROOT environment variable to exist
|
| 3 |
+
# you can replace it with "." if you want the root to be the current working directory
|
| 4 |
+
root_dir: ${oc.env:PROJECT_ROOT}
|
| 5 |
+
|
| 6 |
+
# path to data directory
|
| 7 |
+
noisy_dir: ${paths.root_dir}/data/noisy/
|
| 8 |
+
clean_dir: ${paths.root_dir}/data/clean/
|
| 9 |
+
|
| 10 |
+
# path to logging directory
|
| 11 |
+
log_dir: ${paths.root_dir}/logs/
|
| 12 |
+
|
| 13 |
+
# path to output directory, created dynamically by hydra
|
| 14 |
+
# path generation pattern is specified in `configs/hydra/default.yaml`
|
| 15 |
+
# use it to store all files generated during the run, like ckpts and metrics
|
| 16 |
+
output_dir: ${hydra:runtime.output_dir}
|
| 17 |
+
|
| 18 |
+
# path to working directory
|
| 19 |
+
work_dir: ${hydra:runtime.cwd}
|
fspen/configs/paths/eval.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# path to root directory
|
| 2 |
+
# this requires PROJECT_ROOT environment variable to exist
|
| 3 |
+
# you can replace it with "." if you want the root to be the current working directory
|
| 4 |
+
root_dir: ${oc.env:PROJECT_ROOT}
|
| 5 |
+
|
| 6 |
+
# path to data directory
|
| 7 |
+
noisy_dir: ${paths.root_dir}/voicebank_wavs/test/noisy/
|
| 8 |
+
clean_dir: ${paths.root_dir}/voicebank_wavs/test/clean/
|
| 9 |
+
|
| 10 |
+
# path to logging directory
|
| 11 |
+
log_dir: ${paths.root_dir}/logs/
|
| 12 |
+
|
| 13 |
+
# path to output directory, created dynamically by hydra
|
| 14 |
+
# path generation pattern is specified in `configs/hydra/default.yaml`
|
| 15 |
+
# use it to store all files generated during the run, like ckpts and metrics
|
| 16 |
+
output_dir: ${hydra:runtime.output_dir}
|
| 17 |
+
|
| 18 |
+
# path to working directory
|
| 19 |
+
work_dir: ${hydra:runtime.cwd}
|
fspen/configs/train.yaml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# specify here default configuration
|
| 4 |
+
# order of defaults determines the order in which configs override each other
|
| 5 |
+
defaults:
|
| 6 |
+
- _self_
|
| 7 |
+
- data: speech_enhancement
|
| 8 |
+
- model: fspen
|
| 9 |
+
- callbacks: default
|
| 10 |
+
- logger: tensorboard
|
| 11 |
+
- trainer: gpu
|
| 12 |
+
- paths: default
|
| 13 |
+
- extras: default
|
| 14 |
+
- hydra: default
|
| 15 |
+
|
| 16 |
+
# experiment configs allow for version control of specific hyperparameters
|
| 17 |
+
# e.g. best hyperparameters for given model and datamodule
|
| 18 |
+
- experiment: null
|
| 19 |
+
|
| 20 |
+
# config for hyperparameter optimization
|
| 21 |
+
- hparams_search: null
|
| 22 |
+
|
| 23 |
+
# optional local config for machine/user specific settings
|
| 24 |
+
# it's optional since it doesn't need to exist and is excluded from version control
|
| 25 |
+
- optional local: default
|
| 26 |
+
|
| 27 |
+
# debugging config (enable through command line, e.g. `python train.py debug=default)
|
| 28 |
+
- debug: null
|
| 29 |
+
|
| 30 |
+
# task name, determines output directory path
|
| 31 |
+
task_name: "train"
|
| 32 |
+
|
| 33 |
+
# tags to help you identify your experiments
|
| 34 |
+
# you can overwrite this in experiment configs
|
| 35 |
+
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
|
| 36 |
+
tags: ["dev"]
|
| 37 |
+
|
| 38 |
+
# set False to skip model training
|
| 39 |
+
train: True
|
| 40 |
+
|
| 41 |
+
# evaluate on test set, using best model weights achieved during training
|
| 42 |
+
# lightning chooses best weights based on the metric specified in checkpoint callback
|
| 43 |
+
test: False
|
| 44 |
+
|
| 45 |
+
# simply provide checkpoint path to resume training
|
| 46 |
+
ckpt_path: null
|
| 47 |
+
|
| 48 |
+
# seed for random number generators in pytorch, numpy and python.random
|
| 49 |
+
seed: null
|
fspen/configs/trainer/cpu.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default
|
| 3 |
+
|
| 4 |
+
accelerator: cpu
|
| 5 |
+
devices: 1
|
fspen/configs/trainer/ddp.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default
|
| 3 |
+
|
| 4 |
+
strategy: ddp
|
| 5 |
+
|
| 6 |
+
accelerator: gpu
|
| 7 |
+
devices: 4
|
| 8 |
+
num_nodes: 1
|
| 9 |
+
sync_batchnorm: True
|
fspen/configs/trainer/ddp_sim.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default
|
| 3 |
+
|
| 4 |
+
# simulate DDP on CPU, useful for debugging
|
| 5 |
+
accelerator: cpu
|
| 6 |
+
devices: 2
|
| 7 |
+
strategy: ddp_spawn
|
fspen/configs/trainer/default.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: lightning.pytorch.trainer.Trainer
|
| 2 |
+
|
| 3 |
+
default_root_dir: ${paths.output_dir}
|
| 4 |
+
|
| 5 |
+
min_epochs: 1 # prevents early stopping
|
| 6 |
+
max_epochs: 100
|
| 7 |
+
|
| 8 |
+
accelerator: cpu
|
| 9 |
+
devices: 1
|
| 10 |
+
|
| 11 |
+
# mixed precision for extra speed-up
|
| 12 |
+
# precision: 16
|
| 13 |
+
|
| 14 |
+
# perform a validation loop every N training epochs
|
| 15 |
+
check_val_every_n_epoch: 1
|
| 16 |
+
|
| 17 |
+
# set True to to ensure deterministic results
|
| 18 |
+
# makes training slower but gives more reproducibility than just setting seeds
|
| 19 |
+
deterministic: False
|
fspen/configs/trainer/gpu.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default
|
| 3 |
+
|
| 4 |
+
accelerator: gpu
|
| 5 |
+
devices: 1
|
fspen/configs/trainer/mps.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default
|
| 3 |
+
|
| 4 |
+
accelerator: mps
|
| 5 |
+
devices: 1
|
fspen/environment.yaml
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
##############
|
| 2 |
+
## This environment uses old versions of dl libraries because the training are done on an old GPU
|
| 3 |
+
## mainly torch==2.0.1+cu118
|
| 4 |
+
##############
|
| 5 |
+
name: fspen
|
| 6 |
+
channels:
|
| 7 |
+
- defaults
|
| 8 |
+
- https://repo.anaconda.com/pkgs/main
|
| 9 |
+
- https://repo.anaconda.com/pkgs/r
|
| 10 |
+
dependencies:
|
| 11 |
+
- _libgcc_mutex=0.1
|
| 12 |
+
- _openmp_mutex=5.1
|
| 13 |
+
- bzip2=1.0.8
|
| 14 |
+
- ca-certificates=2025.9.9
|
| 15 |
+
- expat=2.7.1
|
| 16 |
+
- ld_impl_linux-64=2.44
|
| 17 |
+
- libffi=3.4.4
|
| 18 |
+
- libgcc-ng=11.2.0
|
| 19 |
+
- libgomp=11.2.0
|
| 20 |
+
- libnsl=2.0.0
|
| 21 |
+
- libstdcxx-ng=11.2.0
|
| 22 |
+
- libuuid=1.41.5
|
| 23 |
+
- libxcb=1.17.0
|
| 24 |
+
- libzlib=1.3.1
|
| 25 |
+
- ncurses=6.5
|
| 26 |
+
- openssl=3.0.18
|
| 27 |
+
- pip=25.2
|
| 28 |
+
- pthread-stubs=0.3
|
| 29 |
+
- python=3.10.19
|
| 30 |
+
- readline=8.3
|
| 31 |
+
- setuptools=80.9.0
|
| 32 |
+
- sqlite=3.50.2
|
| 33 |
+
- tk=8.6.15
|
| 34 |
+
- wheel=0.45.1
|
| 35 |
+
- xorg-libx11=1.8.12
|
| 36 |
+
- xorg-libxau=1.0.12
|
| 37 |
+
- xorg-libxdmcp=1.1.5
|
| 38 |
+
- xorg-xorgproto=2024.1
|
| 39 |
+
- xz=5.6.4
|
| 40 |
+
- zlib=1.3.1
|
| 41 |
+
- pip:
|
| 42 |
+
- absl-py==2.3.1
|
| 43 |
+
- aiohappyeyeballs==2.6.1
|
| 44 |
+
- aiohttp==3.13.2
|
| 45 |
+
- aiosignal==1.4.0
|
| 46 |
+
- antlr4-python3-runtime==4.9.3
|
| 47 |
+
- async-timeout==5.0.1
|
| 48 |
+
- certifi==2025.10.5
|
| 49 |
+
- charset-normalizer==3.4.4
|
| 50 |
+
- cmake==3.25.0
|
| 51 |
+
- contourpy==1.3.2
|
| 52 |
+
- cycler==0.12.1
|
| 53 |
+
- decorator==5.2.1
|
| 54 |
+
- deprecate==1.0.5
|
| 55 |
+
- dotenv==0.9.9
|
| 56 |
+
- einops==0.8.1
|
| 57 |
+
- filelock==3.19.1
|
| 58 |
+
- fonttools==4.60.1
|
| 59 |
+
- frozenlist==1.8.0
|
| 60 |
+
- fsspec==2025.9.0
|
| 61 |
+
- grpcio==1.76.0
|
| 62 |
+
- h5py==3.15.1
|
| 63 |
+
- hf-xet==1.1.10
|
| 64 |
+
- huggingface-hub==0.35.3
|
| 65 |
+
- hydra-colorlog==1.2.0
|
| 66 |
+
- hydra-core==1.3.2
|
| 67 |
+
- hydra-optuna-sweeper==1.2.0
|
| 68 |
+
- idna==3.11
|
| 69 |
+
- indic-numtowords==1.1.0
|
| 70 |
+
- iniconfig==2.3.0
|
| 71 |
+
- ipdb==0.13.13
|
| 72 |
+
- jinja2==3.1.6
|
| 73 |
+
- kiwisolver==1.4.9
|
| 74 |
+
- lightning==2.3.0
|
| 75 |
+
- lightning-utilities==0.15.2
|
| 76 |
+
- lit==15.0.7
|
| 77 |
+
- markdown==3.9
|
| 78 |
+
- markupsafe==2.1.5
|
| 79 |
+
- matplotlib==3.10.7
|
| 80 |
+
- more-itertools==10.8.0
|
| 81 |
+
- mpmath==1.3.0
|
| 82 |
+
- multidict==6.7.0
|
| 83 |
+
- numpy==1.26.4
|
| 84 |
+
- omegaconf==2.3.0
|
| 85 |
+
- pandas==2.3.3
|
| 86 |
+
- pexpect==4.9.0
|
| 87 |
+
- pillow==12.0.0
|
| 88 |
+
- pluggy==1.6.0
|
| 89 |
+
- pre-commit==4.5.0
|
| 90 |
+
- propcache==0.4.1
|
| 91 |
+
- protobuf==6.33.0
|
| 92 |
+
- ptyprocess==0.7.0
|
| 93 |
+
- pygments==2.19.2
|
| 94 |
+
- pyparsing==3.2.5
|
| 95 |
+
- pytest==9.0.1
|
| 96 |
+
- python-dotenv==1.2.1
|
| 97 |
+
- pytz==2025.2
|
| 98 |
+
- pyyaml==6.0.3
|
| 99 |
+
- regex==2025.10.23
|
| 100 |
+
- rich==14.2.0
|
| 101 |
+
- rootutils==1.0.7
|
| 102 |
+
- safetensors==0.6.2
|
| 103 |
+
- scipy==1.15.3
|
| 104 |
+
- six==1.17.0
|
| 105 |
+
- soundfile==0.13.1
|
| 106 |
+
- sympy==1.14.0
|
| 107 |
+
- tensorboard==2.20.0
|
| 108 |
+
- tensorboard-data-server==0.7.2
|
| 109 |
+
- tokenizers==0.22.1
|
| 110 |
+
- torch==2.0.1+cu118
|
| 111 |
+
- torch-tb-profiler==0.4.3
|
| 112 |
+
- torchaudio==2.0.2+cu118
|
| 113 |
+
- torchinfo==1.8.0
|
| 114 |
+
- torchmetrics==0.11.4
|
| 115 |
+
- torchsummary==1.5.1
|
| 116 |
+
- tqdm==4.67.1
|
| 117 |
+
- transformers==4.57.1
|
| 118 |
+
- triton==2.0.0
|
| 119 |
+
- typing-extensions==4.15.0
|
| 120 |
+
- tzdata==2025.2
|
| 121 |
+
- urllib3==2.5.0
|
| 122 |
+
- werkzeug==3.1.3
|
| 123 |
+
- xxhash==3.6.0
|
| 124 |
+
- yarl==1.22.0
|
| 125 |
+
prefix: /home/infres/lasri-22/miniconda3/envs/fspen
|
fspen/notebooks/.gitkeep
ADDED
|
File without changes
|
fspen/pyproject.toml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.pytest.ini_options]
|
| 2 |
+
addopts = [
|
| 3 |
+
"--color=yes",
|
| 4 |
+
"--durations=0",
|
| 5 |
+
"--strict-markers",
|
| 6 |
+
"--doctest-modules",
|
| 7 |
+
]
|
| 8 |
+
filterwarnings = [
|
| 9 |
+
"ignore::DeprecationWarning",
|
| 10 |
+
"ignore::UserWarning",
|
| 11 |
+
]
|
| 12 |
+
log_cli = "True"
|
| 13 |
+
markers = [
|
| 14 |
+
"slow: slow tests",
|
| 15 |
+
]
|
| 16 |
+
minversion = "6.0"
|
| 17 |
+
testpaths = "tests/"
|
| 18 |
+
|
| 19 |
+
[tool.coverage.report]
|
| 20 |
+
exclude_lines = [
|
| 21 |
+
"pragma: nocover",
|
| 22 |
+
"raise NotImplementedError",
|
| 23 |
+
"raise NotImplementedError()",
|
| 24 |
+
"if __name__ == .__main__.:",
|
| 25 |
+
]
|
fspen/requirements.txt
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------- pytorch --------- #
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
torchvision>=0.15.0
|
| 4 |
+
lightning>=2.0.0
|
| 5 |
+
torchmetrics>=0.11.4
|
| 6 |
+
|
| 7 |
+
# --------- hydra --------- #
|
| 8 |
+
hydra-core==1.3.2
|
| 9 |
+
hydra-colorlog==1.2.0
|
| 10 |
+
hydra-optuna-sweeper==1.2.0
|
| 11 |
+
|
| 12 |
+
# --------- loggers --------- #
|
| 13 |
+
# wandb
|
| 14 |
+
# neptune-client
|
| 15 |
+
# mlflow
|
| 16 |
+
# comet-ml
|
| 17 |
+
# aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550
|
| 18 |
+
|
| 19 |
+
# --------- others --------- #
|
| 20 |
+
rootutils # standardizing the project root setup
|
| 21 |
+
pre-commit # hooks for applying linters on commit
|
| 22 |
+
rich # beautiful text formatting in terminal
|
| 23 |
+
pytest # tests
|
| 24 |
+
# sh # for running bash commands in some tests (linux/macos only)
|
fspen/scripts/schedule.sh
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Schedule execution of many runs
|
| 3 |
+
# Run from root folder with: bash scripts/schedule.sh
|
| 4 |
+
|
| 5 |
+
python src/train.py trainer.max_epochs=5 logger=csv
|
| 6 |
+
|
| 7 |
+
python src/train.py trainer.max_epochs=10 logger=csv
|
fspen/setup.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
from setuptools import find_packages, setup
|
| 4 |
+
|
| 5 |
+
setup(
|
| 6 |
+
name="src",
|
| 7 |
+
version="0.0.1",
|
| 8 |
+
description="Describe Your Cool Project",
|
| 9 |
+
author="",
|
| 10 |
+
author_email="",
|
| 11 |
+
url="https://github.com/user/project",
|
| 12 |
+
install_requires=["lightning", "hydra-core"],
|
| 13 |
+
packages=find_packages(),
|
| 14 |
+
# use this to customize global commands available in the terminal after installing the package
|
| 15 |
+
entry_points={
|
| 16 |
+
"console_scripts": [
|
| 17 |
+
"train_command = src.train:main",
|
| 18 |
+
"eval_command = src.eval:main",
|
| 19 |
+
]
|
| 20 |
+
},
|
| 21 |
+
)
|