Spaces: Running on Zero
Commit 2b2771c
Parent(s): fccca85
update scripts

Files changed:
- .gitattributes +35 -0
- README.md +28 -4
- calculate_metrics.py +182 -0
- config.yaml +8 -17
- data/augment.py +1 -1
- data/dataset.py +38 -53
- evaluation/README.md +0 -31
- evaluation/__init__.py +0 -0
- evaluation/metrics.py +0 -183
- inference.py +113 -0
- models/MelRNN.py +5 -6
- models/UNet.py +5 -6
- modules/generator/ConvNeXt2DBlock.py +1 -1
- train.py +23 -85
- unwrap.py +7 -20
.gitattributes ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -21,6 +21,34 @@ The repository is organized to separate concerns, making it easy to extend and maintain.
 - `discriminator/` <- Discriminator architectures
 - `generator/` <- Reusable generator components
 
+## Run Inference On The Pretrained Models
+
+Download the pretrained models from https://huggingface.co/yongyizang/MSRChallengeBaseline, then run `inference.py` to evaluate them.
+
+```bash
+python inference.py --config config.yaml --checkpoint path/to/your/checkpoint.ckpt --input_dir path/to/your/input/directory --output_dir path/to/your/output/directory
+```
+
+Every `*.flac` file in the `input_dir` will be processed and saved in the `output_dir`.
+
+## Evaluation Script
+
+The evaluation script is provided in `calculate_metrics.py`.
+
+```bash
+python calculate_metrics.py {file list}
+```
+
+The evaluation script expects a file list with each line in the format `{target path}|{output path}`. Results are printed to the console; you can use ` .. > output.txt` to redirect the output to a file.
+
+We recommend modifying this script to fit your needs.
+
+---
+
+For a comprehensive list of arguments, please check each individual script.
+
+---
+
 ## 🚀 Getting Started
 
 ### 1. Setup
@@ -43,8 +71,6 @@ Key sections to update:
 
 - `data.train_dataset.root_directory`: Path to your training data.
 - `data.train_dataset.file_list`: Path to a `.txt` file listing your training samples.
-- `data.val_dataset.root_directory`: Path to your validation data.
-- `data.val_dataset.file_list`: Path to a `.txt` file listing your validation samples.
 - `model`: Choose the generator model and its parameters.
 - `discriminators`: Add and configure one or more discriminators.
 - `trainer`: Set training parameters like `max_steps`, `devices` (GPU IDs), and `precision`.
@@ -57,8 +83,6 @@ Launch the training process using the `train.py` script and your configuration file:
 
 python train.py --config config.yaml
 ```
 
-
 Logs, checkpoints, and audio samples will be saved in the `lightning_logs/` directory.
-
 ### 4. Unwrap Generator Weights
 
 After training, you may want to use the generator model for inference without the rest of the Lightning module. The `unwrap.py` script extracts the generator's `state_dict` from a checkpoint file.
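For illustration, a file list passed to `calculate_metrics.py` could look like the following (hypothetical paths, one `{target path}|{output path}` pair per line):

```
/data/test/targets/song_001.flac|/data/test/outputs/song_001.flac
/data/test/targets/song_002.flac|/data/test/outputs/song_002.flac
```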
calculate_metrics.py ADDED
@@ -0,0 +1,182 @@
import os
import soundfile as sf
import torch
from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
import argparse
import numpy as np
import warnings
from scipy.linalg import sqrtm
from tqdm import tqdm

warnings.filterwarnings("ignore")

try:
    from transformers import ClapModel, ClapProcessor
except ImportError:
    print("Error: The 'transformers' library is not installed.")
    print("Please install it to run FAD-CLAP calculations:")
    print("pip install torch transformers")
    exit(1)


def load_audio(file_path, sr=48000):
    try:
        wav, samplerate = sf.read(file_path)
        if samplerate != sr:
            pass
        if wav.ndim > 1:
            wav = wav.T
        else:
            wav = wav[np.newaxis, :]
        return torch.from_numpy(wav).float()
    except Exception:
        return None

def get_clap_embeddings(file_paths, model, processor, device, batch_size=16):
    model.to(device)
    all_embeddings = []

    for i in tqdm(range(0, len(file_paths), batch_size), desc=" Calculating embeddings", ncols=100, leave=False):
        batch_paths = file_paths[i:i+batch_size]
        audio_batch = []
        for path in batch_paths:
            try:
                wav, sr = sf.read(path)
                if wav.ndim == 2 and wav.shape[1] == 2:
                    audio_batch.append(wav[:, 0])  # Left channel
                    audio_batch.append(wav[:, 1])  # Right channel
                elif wav.ndim == 1:
                    audio_batch.append(wav)
                else:
                    continue
            except Exception:
                continue

        if not audio_batch:
            continue

        try:
            inputs = processor(audios=audio_batch, sampling_rate=48000, return_tensors="pt", padding=True)
            inputs = {key: val.to(device) for key, val in inputs.items()}

            with torch.no_grad():
                audio_features = model.get_audio_features(**inputs)

            all_embeddings.append(audio_features.cpu().numpy())
        except Exception:
            continue

    if not all_embeddings:
        return np.array([])

    return np.concatenate(all_embeddings, axis=0)

def calculate_frechet_distance(embeddings1, embeddings2):
    if embeddings1.shape[0] < 2 or embeddings2.shape[0] < 2:
        return None

    mu1, mu2 = np.mean(embeddings1, axis=0), np.mean(embeddings2, axis=0)
    sigma1, sigma2 = np.cov(embeddings1, rowvar=False), np.cov(embeddings2, rowvar=False)

    ssdiff = np.sum((mu1 - mu2)**2.0)

    try:
        covmean, _ = sqrtm(sigma1.dot(sigma2), disp=False)
    except Exception:
        return None

    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fad_score = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fad_score

def main():
    parser = argparse.ArgumentParser(description="Calculate SI-SNR and FAD-CLAP for audio pairs listed in a text file.")
    parser.add_argument("file_list", type=str, help="Path to a text file with the format: target_path|output_path")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for FAD-CLAP embedding calculation.")
    args = parser.parse_args()

    if not os.path.exists(args.file_list):
        print(f"Error: Input file not found at {args.file_list}")
        return

    sisnr_calculator = ScaleInvariantSignalNoiseRatio()
    all_target_paths = []
    all_output_paths = []

    print("--- Calculating SI-SNR for each pair ---")
    with open(args.file_list, 'r') as f:
        for line in f:
            line = line.strip()
            if not line or '|' not in line:
                continue

            try:
                target_path, output_path = [p.strip() for p in line.split('|')]

                if not os.path.exists(target_path) or not os.path.exists(output_path):
                    print(f"Skipping line, file not found: {line}")
                    continue

                target_wav = load_audio(target_path)
                output_wav = load_audio(output_path)

                if target_wav is None or output_wav is None:
                    continue
                if target_wav.shape[0] != output_wav.shape[0]:
                    continue

                min_len = min(target_wav.shape[-1], output_wav.shape[-1])
                target_wav = target_wav[..., :min_len]
                output_wav = output_wav[..., :min_len]

                if target_wav.shape[-1] == 0:
                    continue

                sisnr_val = sisnr_calculator(output_wav, target_wav)
                print(f"{target_path}|{output_path}|{sisnr_val.item():.4f}")

                all_target_paths.append(target_path)
                all_output_paths.append(output_path)

            except Exception:
                continue

    print("\n--- Calculating FAD-CLAP for all target vs. all output files ---")
    if not all_target_paths:
        print("No valid file pairs found to calculate FAD-CLAP.")
        return

    try:
        print("Loading CLAP model...")
        clap_model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
        clap_processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")
        clap_model.eval()
        print("CLAP model loaded successfully.")
    except Exception as e:
        print(f"Fatal Error: Could not load CLAP model. Please check internet connection. Error: {e}")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("\nCalculating embeddings for all target files...")
    target_embeddings = get_clap_embeddings(all_target_paths, clap_model, clap_processor, device, args.batch_size)

    print("Calculating embeddings for all output files...")
    output_embeddings = get_clap_embeddings(all_output_paths, clap_model, clap_processor, device, args.batch_size)

    if target_embeddings.size > 0 and output_embeddings.size > 0:
        print("Calculating Frechet Audio Distance (FAD)...")
        fad_score = calculate_frechet_distance(target_embeddings, output_embeddings)
        if fad_score is not None:
            print(f"\nOverall FAD-CLAP Score: {fad_score:.4f}")
        else:
            print("\nCould not calculate FAD-CLAP score.")
    else:
        print("\nCould not calculate FAD-CLAP due to issues with embedding generation.")

if __name__ == "__main__":
    main()
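As a quick sanity check of the Fréchet distance helper above, here is a minimal sketch with synthetic embeddings (hypothetical shapes and values; importing `calculate_metrics` requires `transformers` to be installed because of the import-time check):

```python
import numpy as np
from calculate_metrics import calculate_frechet_distance

rng = np.random.default_rng(0)
emb_target = rng.normal(size=(64, 16))                            # stand-in for target embeddings
emb_output = emb_target + rng.normal(scale=0.01, size=(64, 16))   # nearly identical "outputs"

# Identical sets give a score near 0; similar sets give a small positive value.
print(calculate_frechet_distance(emb_target, emb_target))
print(calculate_frechet_distance(emb_target, emb_output))
```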
config.yaml CHANGED
@@ -9,32 +9,24 @@ model:
     num_heads: 4
     window_size: 2048
     hop_size: 512
-    sample_rate:
+    sample_rate: 48000
 
 discriminators:
   - name: "MultiFrequencyDiscriminator"
     params:
       nch: 1
       window_sizes: [2048, 1024, 512]
-      sample_rate:
+      sample_rate: 48000
       norm: True
-  # you can add more discriminators here
 
 data:
-  sample_rate:
+  sample_rate: 48000
   clip_duration: 3.0
   train_dataset:
     target_stem: "Voc"
-    root_directory: "/path/to/your/training/data"
-    file_list: "/path/to/your/train_split.txt"
+    root_directory: "/path/to/your/training/data/dir"
    apply_augmentation: True
     snr_range: [0.0, 10.0]
-  val_dataset:
-    target_stem: "Voc"
-    root_directory: "/path/to/your/validation/data"
-    file_list: "/path/to/your/val_split.txt"
-    apply_augmentation: True
-    snr_range: [5.0, 5.0] # Fixed SNR for validation
   dataloader_params:
     batch_size: 4
     num_workers: 8
@@ -56,15 +48,14 @@ losses:
   lambda_feat: 2.0
   lambda_gan: 1.0
   reconstruction_loss:
-    sample_rate:
+    sample_rate: 48000
     n_fft: [1024, 2048, 512]
     hop_length: [256, 512, 128]
     n_mels: [80, 160, 40]
 
 trainer:
   max_steps: 1000000
-  val_check_interval: 5000
   log_every_n_steps: 100
-
-
-
+  checkpoint_save_interval: 10000
+  devices: [0]
+  precision: bf16-mixed
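To make the new keys concrete, a minimal sketch of reading the updated configuration from Python (assuming the `config.yaml` shown above is in the working directory):

```python
import yaml

# Load the committed training/inference configuration.
with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["data"]["sample_rate"])                  # 48000
print(cfg["trainer"]["checkpoint_save_interval"])  # 10000
print(cfg["trainer"]["devices"], cfg["trainer"]["precision"])  # [0] bf16-mixed
```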
data/augment.py CHANGED
@@ -1,5 +1,5 @@
 import numpy as np
-from eq_utils import apply_random_eq
+from data.eq_utils import apply_random_eq
 from pedalboard import Pedalboard, Resample, Compressor, Distortion, Reverb, Limiter, MP3Compressor
 
 def fix_length_to_duration(target: np.ndarray, duration: float) -> np.ndarray:
data/dataset.py CHANGED
@@ -8,7 +8,7 @@ import json
 from typing import List, Optional, Dict, Union, Tuple, Any
 from torch.utils.data import Dataset, Sampler
 from tqdm import tqdm
-from augment import StemAugmentation, MixtureAugmentation
+from data.augment import StemAugmentation, MixtureAugmentation
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
@@ -84,8 +84,7 @@ class RawStems(Dataset):
         self,
         target_stem: str,
         root_directory: Union[str, Path],
-
-        sr: int = 44100,
+        sr: int = 48000,
         clip_duration: float = 3.0,
         snr_range: Tuple[float, float] = (0.0, 10.0),
         apply_augmentation: bool = True,
@@ -97,18 +96,27 @@ class RawStems(Dataset):
         self.snr_range = snr_range
         self.apply_augmentation = apply_augmentation
         self.rms_threshold = rms_threshold
-
-        self.folders = []
-        with open(file_list, 'r') as f:
-            for line in f:
-                folder = self.root_directory / Path(line.strip())
-                if folder.exists(): self.folders.append(folder)
-                else: logger.warning(f"Folder does not exist: {folder}")
 
         target_stem_parts = target_stem.split("_")
         self.target_stem_1 = target_stem_parts[0].strip()
         self.target_stem_2 = target_stem_parts[1].strip() if len(target_stem_parts) > 1 else None
 
+        logger.info(f"Scanning '{self.root_directory}' for songs containing stem '{target_stem}'...")
+        self.folders = []
+        for song_dir in self.root_directory.iterdir():
+            if song_dir.is_dir():
+                target_path = song_dir / self.target_stem_1
+                if self.target_stem_2:
+                    target_path /= self.target_stem_2
+
+                if target_path.exists() and target_path.is_dir():
+                    self.folders.append(song_dir)
+
+        if not self.folders:
+            raise FileNotFoundError(f"No subdirectories in '{self.root_directory}' were found containing the stem path '{target_stem}'. "
+                                    f"Please check your directory structure.")
+        logger.info(f"Found {len(self.folders)} song folders.")
+
         self.audio_files = self._index_audio_files()
         if not self.audio_files: raise ValueError("No audio files found.")
@@ -150,12 +158,10 @@ class RawStems(Dataset):
             activity_masks[path_str] = np.array([False] * len(rms_values))
             continue
 
-        # Efficiently check if the average RMS in a sliding window is above the threshold
         is_loud = rms_values > self.rms_threshold
         sum_loud = np.convolve(is_loud, np.ones(window_size), 'valid')
-        avg_loud_enough = sum_loud / window_size > 0.8
+        avg_loud_enough = sum_loud / window_size > 0.8
 
-        # Pad the mask to match the original length of rms_values
         mask = np.zeros(len(rms_values), dtype=bool)
         mask[:len(avg_loud_enough)] = avg_loud_enough
         activity_masks[path_str] = mask
@@ -171,13 +177,12 @@ class RawStems(Dataset):
         for file_path in file_paths:
             path_str = str(file_path.relative_to(self.root_directory))
             mask = self.activity_masks.get(path_str)
-            if mask is None: return []
+            if mask is None: return []
             masks_to_intersect.append(mask)
             min_len = min(min_len, len(mask))
 
         if not masks_to_intersect: return []
 
-        # Truncate all masks to the minimum length and intersect
         final_mask = np.ones(min_len, dtype=bool)
         for mask in masks_to_intersect:
             final_mask &= mask[:min_len]
@@ -204,7 +209,7 @@ class RawStems(Dataset):
                 if not is_target:
                     song_dict["others"].append(p)
             except ValueError:
-                continue
+                continue
 
         if song_dict["target_stems"] and song_dict["others"]:
             indexed_songs.append(song_dict)
@@ -226,12 +231,11 @@ class RawStems(Dataset):
             start_second = random.choice(valid_starts)
             offset = start_second + random.uniform(0, 1.0 - (self.clip_duration % 1.0 or 1.0))
 
-            # --- Audio Loading and Mixing ---
             target_mix = sum(load_audio(p, offset, self.clip_duration, self.sr) for p in selected_targets) / num_targets
             other_mix = sum(load_audio(p, offset, self.clip_duration, self.sr) for p in selected_others) / num_others
 
             if not contains_audio_signal(target_mix) or not contains_audio_signal(other_mix):
-                continue
+                continue
 
             target_clean = target_mix.copy()
             target_augmented = self.stem_augmentation.apply(target_mix, self.sr) if self.apply_augmentation else target_mix
@@ -243,16 +247,28 @@ class RawStems(Dataset):
 
             mixture_augmented = self.mixture_augmentation.apply(mixture, self.sr) if self.apply_augmentation else mixture
 
-            # --- Normalization and final prep ---
             max_val = np.max(np.abs(mixture_augmented)) + 1e-8
             mixture_final = mixture_augmented / max_val
             target_final = target_clean / max_val
 
             rescale = np.random.uniform(*DEFAULT_GAIN_RANGE)
+
+            mixture = np.nan_to_num(mixture_final * rescale)
+            target = np.nan_to_num(target_final * rescale)
+
+            target_length = int(self.clip_duration * self.sr)
+            if target.shape[1] != target_length:
+                target = np.pad(target, (0, target_length - target.shape[1]), mode='constant')
+            else:
+                target = target[:, :target_length]
+            if mixture.shape[1] != target_length:
+                mixture = np.pad(mixture, (0, target_length - mixture.shape[1]), mode='constant')
+            else:
+                mixture = mixture[:, :target_length]
 
             return {
-                "mixture": np.nan_to_num(
-                "target": np.nan_to_num(
+                "mixture": np.nan_to_num(mixture),
+                "target": np.nan_to_num(target)
             }
 
         return self.__getitem__(random.randint(0, len(self.audio_files) - 1))
@@ -275,35 +291,4 @@ class InfiniteSampler(Sampler):
         while True:
             if self.pointer >= self.dataset_size: self.reset()
             yield self.indexes[self.pointer]
-            self.pointer += 1
-
-if __name__ == "__main__":
-    root = "/lan/ifc/downloaded_datasets/cambridge-mt/sorted_files"
-    dataset = RawStems(
-        target_stem="Voc",
-        root_directory=root,
-        file_list="/home/yongyizang/music_source_restoration/configs/data_split/Voc_train.txt",
-        sr=44100,
-        clip_duration=10.0,
-        apply_augmentation=True,
-        rms_threshold=-30.0
-    )
-
-    sampler = InfiniteSampler(dataset)
-    iterator = iter(sampler)
-
-    output_dir = Path("./msr_test_set/Voc/")
-    output_dir.mkdir(parents=True, exist_ok=True)
-    logger.info(f"Output directory: {output_dir}")
-
-    for i in tqdm(range(10), desc="Generating test samples"):
-        index = next(iterator)
-        sample = dataset[index]
-
-        mixture_path = output_dir / f"mixture_{i}.wav"
-        target_path = output_dir / f"target_{i}.wav"
-
-        sf.write(mixture_path, sample["mixture"].T, dataset.sr)
-        sf.write(target_path, sample["target"].T, dataset.sr)
-
-    print("Test complete.")
+            self.pointer += 1
evaluation/README.md DELETED
@@ -1,31 +0,0 @@
# Evaluation Module

This directory contains classes for evaluating model performance during validation. All metrics inherit from a base `Metric` class for a consistent interface.

## Files

### `metrics.py`

#### `SI_SNR` (Scale-Invariant Signal-to-Noise Ratio)

A common metric for audio source separation that measures the quality of the restored signal relative to the original target. It is invariant to the overall scaling of the estimated signal.

- `update(pred, target)`: Updates the running statistics with a new batch of predicted and target audio tensors.
- `compute()`: Calculates the mean and standard deviation of the SI-SNR scores accumulated since the last reset.
- `reset()`: Clears the accumulated statistics.

#### `FAD_CLAP` (Fréchet Audio Distance using CLAP)

Measures the Fréchet distance between the distributions of embeddings from the generated audio and the ground truth audio. It uses a pre-trained CLAP (Contrastive Language-Audio Pretraining) model to generate these embeddings, providing a perceptually relevant measure of audio quality and similarity.

**Note:** This metric requires the `laion-clap` library. If not installed, it will fall back to using random embeddings, which is not meaningful for evaluation.

- `update(pred, target)`: Extracts CLAP embeddings from the predicted and target audio tensors and stores them.
- `compute()`: Calculates the FAD score between the collected sets of embeddings.
- `reset()`: Clears the stored embeddings.

**`__init__` Arguments:**

- `embedding_dim` (`int`): The dimensionality of the embeddings. Should match the CLAP model. Default: `512`.
- `model_name` (`str`): The name of the CLAP model architecture to use. Default: `'HTSAT-base'`.
- `ckpt_path` (`Optional[str]`): Optional path to a specific CLAP model checkpoint. If `None`, it uses the default pre-trained weights.
evaluation/__init__.py DELETED
File without changes
evaluation/metrics.py DELETED
@@ -1,183 +0,0 @@
import torch
import torch.nn as nn
import logging
from typing import Dict, List, Optional, Any, Tuple
from abc import ABC, abstractmethod

try:
    import laion_clap
except ImportError:
    raise ImportError(
        "The `laion_clap` package is required for the FAD metric. "
        "Please install it with: pip install laion-clap"
    )

class Metric(nn.Module, ABC):
    def __init__(self):
        super().__init__()
        self.register_buffer("dummy_buffer", torch.empty(0))

    @property
    def device(self) -> torch.device:
        return self.dummy_buffer.device

    @abstractmethod
    def reset(self):
        raise NotImplementedError

    @abstractmethod
    def update(self, *args: Any, **kwargs: Any):
        raise NotImplementedError

    @abstractmethod
    def compute(self) -> Dict[str, float]:
        raise NotImplementedError

class SI_SNR(Metric):
    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps
        self.reset()

    def reset(self):
        self.register_buffer("sum_scores", torch.tensor(0.0, dtype=torch.float64))
        self.register_buffer("sum_sq_scores", torch.tensor(0.0, dtype=torch.float64))
        self.register_buffer("count", torch.tensor(0, dtype=torch.int64))

    def update(self, pred: torch.Tensor, target: torch.Tensor):
        score = self._compute_si_snr(pred, target).detach()
        self.sum_scores += torch.sum(score)
        self.sum_sq_scores += torch.sum(score.pow(2))
        self.count += score.numel()

    def compute(self) -> Dict[str, float]:
        if self.count.item() == 0:
            return {'mean': 0.0, 'std': 0.0, 'count': 0}

        total_count = self.count.item()
        mean_val = (self.sum_scores / self.count).item()
        var = (self.sum_sq_scores / self.count) - (self.sum_scores / self.count).pow(2)
        std_val = torch.sqrt(var).item() if var > 0 and total_count > 1 else 0.0

        return {'mean': mean_val, 'std': std_val, 'count': int(total_count)}

    def _compute_si_snr(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        pred = pred.view(-1, pred.shape[-1])
        target = target.view(-1, target.shape[-1])
        pred_zm = pred - pred.mean(dim=-1, keepdim=True)
        target_zm = target - target.mean(dim=-1, keepdim=True)
        alpha = (pred_zm * target_zm).sum(dim=-1, keepdim=True) / \
                (target_zm.pow(2).sum(dim=-1, keepdim=True) + self.eps)
        target_scaled = alpha * target_zm
        noise = pred_zm - target_scaled
        si_snr_val = (target_scaled.pow(2).sum(dim=-1)) / \
                     (noise.pow(2).sum(dim=-1) + self.eps)
        return 10 * torch.log10(si_snr_val + self.eps)

class FAD_CLAP(Metric):
    def __init__(self, embedding_dim: int = 512, model_name: str = 'HTSAT-base', ckpt_path: Optional[str] = None):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.clap_model = self._load_clap_model(model_name, ckpt_path)

        self.pred_embeddings: List[torch.Tensor] = []
        self.target_embeddings: List[torch.Tensor] = []
        self.reset()

    def _load_clap_model(self, model_name: str, ckpt_path: Optional[str]) -> Optional[nn.Module]:
        if laion_clap is None:
            logging.warning("`laion_clap` is not installed. FAD will use random embeddings.")
            return None
        try:
            logging.info(f"Loading CLAP model '{model_name}' for FAD metric...")
            model = laion_clap.CLAP_Module(enable_fusion=False, amodel=model_name)
            model.load_ckpt(ckpt_path)
            model.eval()
            logging.info("CLAP model loaded successfully.")
            return model
        except Exception as e:
            logging.warning(f"Failed to load CLAP model due to an error: {e}. FAD will use random embeddings.")
            return None

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        return self

    def reset(self):
        self.pred_embeddings.clear()
        self.target_embeddings.clear()

    def update(self, pred: torch.Tensor, target: torch.Tensor):
        self.pred_embeddings.append(self._extract_embedding(pred).cpu())
        self.target_embeddings.append(self._extract_embedding(target).cpu())

    def compute(self) -> Dict[str, float]:
        if not self.pred_embeddings or not self.target_embeddings:
            return {'fad': float('inf'), 'count': 0}

        pred_emb_all = torch.cat(self.pred_embeddings, dim=0).to(self.device)
        target_emb_all = torch.cat(self.target_embeddings, dim=0).to(self.device)

        if pred_emb_all.shape[0] < 2 or target_emb_all.shape[0] < 2:
            logging.warning(f"FAD requires at least 2 samples per set, but got {pred_emb_all.shape[0]} and {target_emb_all.shape[0]}.")
            return {'fad': float('inf'), 'count': pred_emb_all.shape[0]}

        mu_pred, sigma_pred = self._get_mu_and_sigma(pred_emb_all)
        mu_target, sigma_target = self._get_mu_and_sigma(target_emb_all)
        fad_score = self._frechet_distance(mu_pred, sigma_pred, mu_target, sigma_target)
        return {'fad': fad_score.item(), 'count': len(pred_emb_all)}

    @torch.no_grad()
    def _extract_embedding(self, audio: torch.Tensor) -> torch.Tensor:
        if self.clap_model is None:
            return torch.randn(audio.shape[0], self.embedding_dim, device=audio.device)

        self.clap_model.to(audio.device)

        audio_dict = {'waveform': audio, 'sample_rate': 48000}
        return self.clap_model.get_audio_embedding_from_data(x=audio_dict, use_tensor=True)

    def _get_mu_and_sigma(self, embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        mu = embeddings.mean(dim=0)
        sigma = torch.cov(embeddings.T)
        return mu, sigma

    def _frechet_distance(self, mu1, sigma1, mu2, sigma2) -> torch.Tensor:
        diff = mu1 - mu2
        mean_dist_sq = diff.dot(diff)
        try:
            offset = torch.eye(sigma1.shape[0], device=self.device, dtype=sigma1.dtype) * 1e-6
            cov_sqrt = torch.linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)).real
        except RuntimeError:
            logging.warning("Matrix square root failed. Using diagonal approximation for FAD.")
            cov_sqrt = torch.sqrt(torch.diag(sigma1) * torch.diag(sigma2))
        trace_term = torch.trace(sigma1) + torch.trace(sigma2) - 2 * torch.trace(cov_sqrt)
        return mean_dist_sq + trace_term

if __name__ == '__main__':
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Initializing FAD metric...")
    fad_metric = FAD_CLAP()
    fad_metric.to(device)

    sample_rate = 48000
    dummy_pred_audio_batch1 = torch.randn(4, sample_rate * 2, device=device)
    dummy_target_audio_batch1 = torch.randn(4, sample_rate * 2, device=device)

    dummy_pred_audio_batch2 = torch.randn(4, sample_rate * 2, device=device)
    dummy_target_audio_batch2 = torch.randn(4, sample_rate * 2, device=device)

    print("\nUpdating metric with batch 1...")
    fad_metric.update(pred=dummy_pred_audio_batch1, target=dummy_target_audio_batch1)

    print("Updating metric with batch 2...")
    fad_metric.update(pred=dummy_pred_audio_batch2, target=dummy_target_audio_batch2)

    print("\nComputing final FAD score...")
    final_fad_score = fad_metric.compute()

    print(f"Final FAD results: {final_fad_score}")

    fad_metric.reset()
    print("\nMetric has been reset.")
    print(f"State after reset: pred_embeddings={fad_metric.pred_embeddings}, target_embeddings={fad_metric.target_embeddings}")
inference.py ADDED
@@ -0,0 +1,113 @@
import argparse
import yaml
from pathlib import Path
from typing import Dict, Any

import torch
import torch.nn as nn
import soundfile as sf
import numpy as np
from tqdm import tqdm

from models import MelRNN, MelRoFormer, UNet


def load_generator(config: Dict[str, Any], checkpoint_path: str, device: str = 'cuda') -> nn.Module:
    """Initialize and load the generator model from unwrapped checkpoint."""
    model_cfg = config['model']

    # Initialize generator based on config
    if model_cfg['name'] == 'MelRNN':
        generator = MelRNN.MelRNN(**model_cfg['params'])
    elif model_cfg['name'] == 'MelRoFormer':
        generator = MelRoFormer.MelRoFormer(**model_cfg['params'])
    elif model_cfg['name'] == 'MelUNet':
        generator = UNet.MelUNet(**model_cfg['params'])
    else:
        raise ValueError(f"Unknown model name: {model_cfg['name']}")

    # Load unwrapped generator weights
    state_dict = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(state_dict)

    generator = generator.to(device)
    generator.eval()

    return generator


def process_audio(audio: np.ndarray, generator: nn.Module, device: str = 'cuda') -> np.ndarray:
    """Process a single audio array through the generator."""
    # Convert to tensor: (channels, samples) -> (1, channels, samples)
    if audio.ndim == 1:
        audio = audio[np.newaxis, :]  # Add channel dimension for mono

    audio_tensor = torch.from_numpy(audio).float().to(device)

    # Run inference
    with torch.no_grad():
        output_tensor = generator(audio_tensor)

    # Convert back to numpy: (1, channels, samples) -> (channels, samples)
    output_audio = output_tensor.cpu().numpy()

    return output_audio


def main():
    parser = argparse.ArgumentParser(description="Run inference on audio files using trained generator")
    parser.add_argument("--config", type=str, required=True, help="Path to config.yaml")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to unwrapped generator weights (.pth)")
    parser.add_argument("--input_dir", type=str, required=True, help="Directory containing input .flac files")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory to save processed audio")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (cuda/cpu)")
    args = parser.parse_args()

    # Load config
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    # Setup paths
    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Get all .flac files
    audio_files = sorted(input_dir.glob("*.flac"))

    if len(audio_files) == 0:
        print(f"No .flac files found in {input_dir}")
        return

    print(f"Found {len(audio_files)} audio files")

    # Load model
    print(f"Loading generator from {args.checkpoint}...")
    generator = load_generator(config, args.checkpoint, device=args.device)
    print("Model loaded successfully")

    # Process each file
    for audio_file in tqdm(audio_files, desc="Processing audio files"):
        # Load audio
        audio, sr = sf.read(audio_file)

        # Transpose if needed: soundfile loads as (samples, channels)
        if audio.ndim == 2:
            audio = audio.T  # Convert to (channels, samples)

        # Process through generator
        output_audio = process_audio(audio, generator, device=args.device)

        # Transpose back for saving: (channels, samples) -> (samples, channels)
        if output_audio.ndim == 2:
            output_audio = output_audio.T

        # Save with same filename
        output_path = output_dir / audio_file.name
        sf.write(output_path, output_audio, sr)

    print(f"\nProcessing complete! Output saved to {output_dir}")


if __name__ == '__main__':
    main()
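For completeness, a minimal sketch of calling the new helpers programmatically rather than through the CLI (the file paths here are placeholders, and the weights are assumed to be the unwrapped output of `unwrap.py`):

```python
import yaml
import soundfile as sf
from inference import load_generator, process_audio

with open("config.yaml") as f:
    config = yaml.safe_load(f)

# Load the unwrapped generator weights (placeholder path).
generator = load_generator(config, "checkpoints/generator.pth", device="cpu")

audio, sr = sf.read("example_input.flac")
if audio.ndim == 2:
    audio = audio.T  # (channels, samples), as expected by process_audio

restored = process_audio(audio, generator, device="cpu")
sf.write("example_output.flac", restored.T if restored.ndim == 2 else restored, sr)
```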
models/MelRNN.py CHANGED
@@ -25,8 +25,8 @@ class MelRNN(nn.Module):
 
     def forward(self, x):
         original_length = x.shape[1]
-
-        x = self.band.split(
+        x = self.fourier.stft(x)
+        x = self.band.split(x) # (B, C, T, F)
 
         x = rearrange(x, 'b c t f -> b t f c')
         b, t, f, c = x.shape
@@ -40,13 +40,12 @@ class MelRNN(nn.Module):
         x = rearrange(x, '(b f) t c -> b t f c', f=f)
 
         x = rearrange(x, 'b t f c -> b c t f')
-
-
-        x = self.fourier.istft(identity, original_length)
+        x = self.band.unsplit(x)
+        x = self.fourier.istft(x.contiguous(), original_length)
         return x
 
 if __name__ == "__main__":
-    model = MelRNN(hidden_channels=128, num_layers=
+    model = MelRNN(hidden_channels=128, num_layers=9, num_groups=8, window_size=2048, hop_size=512, sample_rate=48000)
 
     x = torch.randn(4, 96000)
models/UNet.py CHANGED
@@ -23,8 +23,8 @@ class MelUNet(nn.Module):
 
     def forward(self, x):
         original_length = x.shape[1]
-
-        x = self.band.split(
+        x = self.fourier.stft(x)
+        x = self.band.split(x) # (B, C, T, F)
 
         residuals = []
         for i in range(self.num_layers):
@@ -37,14 +37,13 @@ class MelUNet(nn.Module):
             if i < self.num_layers - 1:
                 x = x + residual
 
-
-
-        x = self.fourier.istft(identity, original_length)
+        x = self.band.unsplit(x)
+        x = self.fourier.istft(x.contiguous(), original_length)
         return x
 
 if __name__ == "__main__":
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model = MelUNet(hidden_channels=
+    model = MelUNet(hidden_channels=32, num_layers=4, upsampling_factor=2, window_size=2048, hop_size=512, sample_rate=48000)
 
     x = torch.randn(4, 96000)
     x = x.to(device)
modules/generator/ConvNeXt2DBlock.py CHANGED
@@ -35,7 +35,7 @@ class ConvNeXt2DBlock(nn.Module):
         self.dwconv = nn.ConvTranspose2d(dim, dim, kernel_size=kernel_size, stride=stride, padding=self.padding)
         self.residual_conv = nn.ConvTranspose2d(dim, output_dim, kernel_size=kernel_size, stride=stride, padding=self.padding)
         self.norm = RMSNorm(dim)
-        self.n_hidden = int(
+        self.n_hidden = int(4 * dim / 3)
         self.pwconv1 = nn.Linear(dim, self.n_hidden * 2)
         self.pwconv2 = nn.Linear(self.n_hidden, output_dim)
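As a worked illustration of the restored expression (with a hypothetical `dim = 96`): `n_hidden = int(4 * 96 / 3) = 128`, so `pwconv1` maps 96 input features to `128 * 2 = 256` outputs and `pwconv2` maps 128 features to `output_dim`.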
train.py
CHANGED
|
@@ -2,7 +2,7 @@ import argparse
|
|
| 2 |
import yaml
|
| 3 |
from pathlib import Path
|
| 4 |
from typing import Dict, Any, List
|
| 5 |
-
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
from torch.utils.data import DataLoader
|
|
@@ -10,16 +10,10 @@ import pytorch_lightning as pl
|
|
| 10 |
from pytorch_lightning.loggers import TensorBoardLogger
|
| 11 |
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
|
| 12 |
|
| 13 |
-
import numpy as np
|
| 14 |
-
import soundfile as sf
|
| 15 |
-
import matplotlib.pyplot as plt
|
| 16 |
-
import librosa
|
| 17 |
-
|
| 18 |
from data.dataset import RawStems, InfiniteSampler
|
| 19 |
from models import MelRNN, MelRoFormer, UNet
|
| 20 |
from losses.gan_loss import GeneratorLoss, DiscriminatorLoss, FeatureMatchingLoss
|
| 21 |
from losses.reconstruction_loss import MultiMelSpecReconstructionLoss
|
| 22 |
-
from evaluation.metrics import SI_SNR, FAD_CLAP
|
| 23 |
|
| 24 |
from modules.discriminator.MultiPeriodDiscriminator import MultiPeriodDiscriminator
|
| 25 |
from modules.discriminator.MultiScaleDiscriminator import MultiScaleDiscriminator
|
|
@@ -55,12 +49,11 @@ class CombinedDiscriminator(nn.Module):
|
|
| 55 |
return all_scores, all_fmaps
|
| 56 |
|
| 57 |
class MusicRestorationDataModule(pl.LightningDataModule):
|
| 58 |
-
"""Handles data loading for training
|
| 59 |
def __init__(self, config: Dict[str, Any]):
|
| 60 |
super().__init__()
|
| 61 |
self.config = config
|
| 62 |
self.train_dataset = None
|
| 63 |
-
self.val_dataset = None
|
| 64 |
|
| 65 |
def setup(self, stage: str | None = None):
|
| 66 |
common_params = {
|
|
@@ -68,7 +61,6 @@ class MusicRestorationDataModule(pl.LightningDataModule):
|
|
| 68 |
"clip_duration": self.config['clip_duration'],
|
| 69 |
}
|
| 70 |
self.train_dataset = RawStems(**self.config['train_dataset'], **common_params)
|
| 71 |
-
self.val_dataset = RawStems(**self.config['val_dataset'], **common_params)
| 72 |
| 73 |     def train_dataloader(self):
| 74 |         sampler = InfiniteSampler(self.train_dataset)

@@ -77,13 +69,6 @@ class MusicRestorationDataModule(pl.LightningDataModule):
| 77 |             sampler=sampler,
| 78 |             **self.config['dataloader_params']
| 79 |         )
| 80 | -
| 81 | -    def val_dataloader(self):
| 82 | -        return DataLoader(
| 83 | -            self.val_dataset,
| 84 | -            shuffle=False,
| 85 | -            **self.config['dataloader_params']
| 86 | -        )
| 87 |
| 88 | class MusicRestorationModule(pl.LightningModule):
| 89 |     """

@@ -108,18 +93,12 @@ class MusicRestorationModule(pl.LightningModule):
| 108 |         self.loss_feat = FeatureMatchingLoss()
| 109 |         self.loss_recon = MultiMelSpecReconstructionLoss(**loss_cfg['reconstruction_loss'])
| 110 |
| 111 | -        # 4. Validation Metrics
| 112 | -        self.val_si_snr = SI_SNR()
| 113 | -        # Note: FAD_CLAP requires `laion_clap` to be installed.
| 114 | -        # It will gracefully fall back to random embeddings if not found.
| 115 | -        self.val_fad = FAD_CLAP()
| 116 | -
| 117 |     def _init_generator(self):
| 118 |         model_cfg = self.hparams.model
| 119 |         if model_cfg['name'] == 'MelRNN':
| 120 | -            return MelRNN(**model_cfg['params'])
| 121 |         elif model_cfg['name'] == 'MelRoFormer':
| 122 | -            return MelRoFormer(**model_cfg['params'])
| 123 |         elif model_cfg['name'] == 'MelUNet':
| 124 |             return UNet.MelUNet(**model_cfg['params'])
| 125 |         else:

@@ -133,12 +112,16 @@ class MusicRestorationModule(pl.LightningModule):
| 133 |
| 134 |         target = batch['target']
| 135 |         mixture = batch['mixture']
| 136 |
| 137 |         # --- Train Discriminator ---
| 138 |         generated = self(mixture)
| 139 |
| 140 | -        real_scores, _ = self.discriminator(target)
| 141 | -        fake_scores, _ = self.discriminator(generated.detach())
| 142 |
| 143 |         d_loss, _, _ = self.loss_disc_adv(real_scores, fake_scores)
| 144 |

@@ -148,8 +131,8 @@ class MusicRestorationModule(pl.LightningModule):
| 148 |         self.log('train/d_loss', d_loss, prog_bar=True)
| 149 |
| 150 |         # --- Train Generator ---
| 151 | -        real_scores, real_fmaps = self.discriminator(target)
| 152 | -        fake_scores, fake_fmaps = self.discriminator(generated)
| 153 |
| 154 |         # Reconstruction Loss
| 155 |         loss_recon = self.loss_recon(generated, target)

@@ -180,52 +163,6 @@ class MusicRestorationModule(pl.LightningModule):
| 180 |         sch_g, sch_d = self.lr_schedulers()
| 181 |         if sch_g: sch_g.step()
| 182 |         if sch_d: sch_d.step()
| 183 | -
| 184 | -    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int):
| 185 | -        target = batch['target']
| 186 | -        mixture = batch['mixture']
| 187 | -
| 188 | -        generated = self(mixture)
| 189 | -
| 190 | -        loss_recon = self.loss_recon(generated, target)
| 191 | -        self.log('val/loss_recon', loss_recon, on_step=False, on_epoch=True, sync_dist=True)
| 192 | -
| 193 | -        self.val_si_snr.update(generated.detach(), target.detach())
| 194 | -        self.val_fad.update(generated.detach(), target.detach())
| 195 | -
| 196 | -        # Log one audio example and spectrogram per validation epoch
| 197 | -        if batch_idx == 0:
| 198 | -            self._log_media(mixture[0], target[0], generated[0])
| 199 | -
| 200 | -    def on_validation_epoch_end(self):
| 201 | -        si_snr_results = self.val_si_snr.compute()
| 202 | -        fad_results = self.val_fad.compute()
| 203 | -
| 204 | -        self.log('val/si_snr', si_snr_results['mean'], sync_dist=True)
| 205 | -        self.log('val/fad', fad_results['fad'], sync_dist=True)
| 206 | -
| 207 | -        self.val_si_snr.reset()
| 208 | -        self.val_fad.reset()
| 209 | -
| 210 | -    def _log_media(self, mixture: torch.Tensor, target: torch.Tensor, generated: torch.Tensor):
| 211 | -        sr = self.hparams.data['sample_rate']
| 212 | -
| 213 | -        # Log audio
| 214 | -        self.logger.experiment.add_audio("val_audio/mixture", mixture.mean(0).cpu(), self.global_step, sample_rate=sr)
| 215 | -        self.logger.experiment.add_audio("val_audio/target", target.mean(0).cpu(), self.global_step, sample_rate=sr)
| 216 | -        self.logger.experiment.add_audio("val_audio/generated", generated.mean(0).cpu(), self.global_step, sample_rate=sr)
| 217 | -
| 218 | -        # Log spectrograms
| 219 | -        fig, axes = plt.subplots(3, 1, figsize=(10, 12))
| 220 | -        for i, (title, audio) in enumerate([("Mixture", mixture), ("Target", target), ("Generated", generated)]):
| 221 | -            audio_np = audio.mean(0).cpu().numpy().astype(np.float32)
| 222 | -            mel_spec = librosa.feature.melspectrogram(y=audio_np, sr=sr)
| 223 | -            mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
| 224 | -            librosa.display.specshow(mel_spec_db, sr=sr, x_axis='time', y_axis='mel', ax=axes[i])
| 225 | -            axes[i].set_title(title)
| 226 | -        plt.tight_layout()
| 227 | -        self.logger.experiment.add_figure("val_spectrograms", fig, self.global_step)
| 228 | -        plt.close(fig)
| 229 |
| 230 |     def configure_optimizers(self):
| 231 |         # Generator Optimizer

@@ -248,7 +185,7 @@ class MusicRestorationModule(pl.LightningModule):
| 248 |
| 249 | def main():
| 250 |     parser = argparse.ArgumentParser(description="Train a Music Source Restoration Model")
| 251 | -    parser.add_argument("--config", type=str, required=True, help="Path to the config
| 252 |     args = parser.parse_args()
| 253 |
| 254 |     with open(args.config, 'r') as f:

@@ -258,24 +195,26 @@ def main():
| 258 |
| 259 |     data_module = MusicRestorationDataModule(config['data'])
| 260 |     model_module = MusicRestorationModule(config)
| 261 | -
| 262 | -
| 263 |
| 264 |     # Callbacks
| 265 |     checkpoint_callback = ModelCheckpoint(
| 266 |         dirpath=save_dir / "checkpoints",
| 267 | -        filename="{step:08d}
| 268 | -        every_n_train_steps=config['trainer']['
| 269 | -        save_top_k=-1,
| 270 |         auto_insert_metric_name=False
| 271 |     )
| 272 |     lr_monitor = LearningRateMonitor(logging_interval='step')
| 273 |
| 274 |     # Logger
| 275 |     logger = TensorBoardLogger(
| 276 | -        save_dir=
| 277 |         name=config['project_name'],
| 278 | -        version=
| 279 |     )
| 280 |
| 281 |     # Trainer

@@ -283,11 +222,10 @@ def main():
| 283 |         logger=logger,
| 284 |         callbacks=[checkpoint_callback, lr_monitor],
| 285 |         max_steps=config['trainer']['max_steps'],
| 286 | -        val_check_interval=config['trainer']['val_check_interval'],
| 287 |         log_every_n_steps=config['trainer']['log_every_n_steps'],
| 288 |         devices=config['trainer']['devices'],
| 289 |         precision=config['trainer']['precision'],
| 290 | -        accelerator="gpu"
| 291 |     )
| 292 |
| 293 |     trainer.fit(model_module, datamodule=data_module)
| 2 | import yaml
| 3 | from pathlib import Path
| 4 | from typing import Dict, Any, List
| 5 | + from einops import rearrange
| 6 | import torch
| 7 | import torch.nn as nn
| 8 | from torch.utils.data import DataLoader
| 10 | from pytorch_lightning.loggers import TensorBoardLogger
| 11 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
| 12 |
| 13 | from data.dataset import RawStems, InfiniteSampler
| 14 | from models import MelRNN, MelRoFormer, UNet
| 15 | from losses.gan_loss import GeneratorLoss, DiscriminatorLoss, FeatureMatchingLoss
| 16 | from losses.reconstruction_loss import MultiMelSpecReconstructionLoss
| 17 |
| 18 | from modules.discriminator.MultiPeriodDiscriminator import MultiPeriodDiscriminator
| 19 | from modules.discriminator.MultiScaleDiscriminator import MultiScaleDiscriminator

| 49 |         return all_scores, all_fmaps
| 50 |
| 51 | class MusicRestorationDataModule(pl.LightningDataModule):
| 52 | +    """Handles data loading for training."""
| 53 |     def __init__(self, config: Dict[str, Any]):
| 54 |         super().__init__()
| 55 |         self.config = config
| 56 |         self.train_dataset = None
| 57 |
| 58 |     def setup(self, stage: str | None = None):
| 59 |         common_params = {
| 61 |             "clip_duration": self.config['clip_duration'],
| 62 |         }
| 63 |         self.train_dataset = RawStems(**self.config['train_dataset'], **common_params)
| 64 |
| 65 |     def train_dataloader(self):
| 66 |         sampler = InfiniteSampler(self.train_dataset)
| 69 |             sampler=sampler,
| 70 |             **self.config['dataloader_params']
| 71 |         )
| 72 |
| 73 | class MusicRestorationModule(pl.LightningModule):
| 74 |     """

| 93 |         self.loss_feat = FeatureMatchingLoss()
| 94 |         self.loss_recon = MultiMelSpecReconstructionLoss(**loss_cfg['reconstruction_loss'])
| 95 |
| 96 |     def _init_generator(self):
| 97 |         model_cfg = self.hparams.model
| 98 |         if model_cfg['name'] == 'MelRNN':
| 99 | +            return MelRNN.MelRNN(**model_cfg['params'])
| 100 |         elif model_cfg['name'] == 'MelRoFormer':
| 101 | +            return MelRoFormer.MelRoFormer(**model_cfg['params'])
| 102 |         elif model_cfg['name'] == 'MelUNet':
| 103 |             return UNet.MelUNet(**model_cfg['params'])
| 104 |         else:

| 112 |
| 113 |         target = batch['target']
| 114 |         mixture = batch['mixture']
| 115 | +
| 116 | +        # reshape both from (b, c, t) to ((b, c) t)
| 117 | +        target = rearrange(target, 'b c t -> (b c) t')
| 118 | +        mixture = rearrange(mixture, 'b c t -> (b c) t')
| 119 |
| 120 |         # --- Train Discriminator ---
| 121 |         generated = self(mixture)
| 122 |
| 123 | +        real_scores, _ = self.discriminator(target.unsqueeze(1))
| 124 | +        fake_scores, _ = self.discriminator(generated.detach().unsqueeze(1))
| 125 |
| 126 |         d_loss, _, _ = self.loss_disc_adv(real_scores, fake_scores)
| 127 |

| 131 |         self.log('train/d_loss', d_loss, prog_bar=True)
| 132 |
| 133 |         # --- Train Generator ---
| 134 | +        real_scores, real_fmaps = self.discriminator(target.unsqueeze(1))
| 135 | +        fake_scores, fake_fmaps = self.discriminator(generated.unsqueeze(1))
| 136 |
| 137 |         # Reconstruction Loss
| 138 |         loss_recon = self.loss_recon(generated, target)

| 163 |         sch_g, sch_d = self.lr_schedulers()
| 164 |         if sch_g: sch_g.step()
| 165 |         if sch_d: sch_d.step()
| 166 |
| 167 |     def configure_optimizers(self):
| 168 |         # Generator Optimizer

| 185 |
| 186 | def main():
| 187 |     parser = argparse.ArgumentParser(description="Train a Music Source Restoration Model")
| 188 | +    parser.add_argument("--config", type=str, required=True, help="Path to the config file.")
| 189 |     args = parser.parse_args()
| 190 |
| 191 |     with open(args.config, 'r') as f:

| 195 |
| 196 |     data_module = MusicRestorationDataModule(config['data'])
| 197 |     model_module = MusicRestorationModule(config)
| 198 | +
| 199 | +    exp_name = f"{config['model']['name']}"
| 200 | +    exp_name = exp_name.replace(" ", "_")
| 201 | +    save_dir = Path(config['trainer']['save_dir']) / config['project_name'] / exp_name
| 202 |
| 203 |     # Callbacks
| 204 |     checkpoint_callback = ModelCheckpoint(
| 205 |         dirpath=save_dir / "checkpoints",
| 206 | +        filename="{step:08d}",
| 207 | +        every_n_train_steps=config['trainer']['checkpoint_save_interval'],
| 208 | +        save_top_k=-1,
| 209 |         auto_insert_metric_name=False
| 210 |     )
| 211 |     lr_monitor = LearningRateMonitor(logging_interval='step')
| 212 |
| 213 |     # Logger
| 214 |     logger = TensorBoardLogger(
| 215 | +        save_dir=config['trainer']['save_dir'],
| 216 |         name=config['project_name'],
| 217 | +        version=exp_name
| 218 |     )
| 219 |
| 220 |     # Trainer

| 222 |         logger=logger,
| 223 |         callbacks=[checkpoint_callback, lr_monitor],
| 224 |         max_steps=config['trainer']['max_steps'],
| 225 |         log_every_n_steps=config['trainer']['log_every_n_steps'],
| 226 |         devices=config['trainer']['devices'],
| 227 |         precision=config['trainer']['precision'],
| 228 | +        accelerator="gpu"
| 229 |     )
| 230 |
| 231 |     trainer.fit(model_module, datamodule=data_module)
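The `training_step` changes above are mostly about tensor shapes: stereo clips are folded into the batch axis before the generator, and a channel axis is added back before the discriminators. Below is a minimal sketch of that shape flow using dummy tensors; the batch size, channel count, and clip length are illustrative values, not taken from the repository.

```python
import torch
from einops import rearrange

# Dummy stereo batch: (batch, channels, time), stand-in for batch['target'] / batch['mixture']
target = torch.randn(4, 2, 44100)
mixture = torch.randn(4, 2, 44100)

# Fold channels into the batch axis so each waveform is processed independently.
target = rearrange(target, 'b c t -> (b c) t')    # -> (8, 44100)
mixture = rearrange(mixture, 'b c t -> (b c) t')  # -> (8, 44100)

# The discriminators are fed tensors with an explicit channel axis,
# which is why training_step calls unsqueeze(1) on target/generated.
disc_input = target.unsqueeze(1)                  # -> (8, 1, 44100)
print(mixture.shape, disc_input.shape)
```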
unwrap.py
CHANGED
@@ -1,5 +1,5 @@
| 1 | import torch
| 2 | - import
| 3 | from collections import OrderedDict
| 4 | from pathlib import Path
| 5 |

@@ -33,22 +33,9 @@ def unwrap_generator_checkpoint(ckpt_path: str, output_path: str) -> None:
| 33 |     torch.save(generator_state_dict, output_path)
| 34 |
| 35 | if __name__ == '__main__':
| 36 | -
| 37 | -
| 38 | -    )
| 39 | -
| 40 | -        '
| 41 | -
| 42 | -        required=True,
| 43 | -        help="Path to the input PyTorch Lightning checkpoint file (.ckpt)."
| 44 | -    )
| 45 | -    parser.add_argument(
| 46 | -        '--out',
| 47 | -        type=str,
| 48 | -        required=True,
| 49 | -        help="Path to save the unwrapped generator weights (.pth)."
| 50 | -    )
| 51 | -
| 52 | -    args = parser.parse_args()
| 53 | -
| 54 | -    unwrap_generator_checkpoint(args.ckpt, args.out)

| 1 | import torch
| 2 | + import os, glob
| 3 | from collections import OrderedDict
| 4 | from pathlib import Path
| 5 |

| 33 |     torch.save(generator_state_dict, output_path)
| 34 |
| 35 | if __name__ == '__main__':
| 36 | +    input_dir = "/root/autodl-tmp/checkpoints/mel-unet"
| 37 | +    # find all .ckpt files in the input directory
| 38 | +    ckpt_files = glob.glob(os.path.join(input_dir, '*.ckpt'))
| 39 | +    for ckpt_file in ckpt_files:
| 40 | +        unwrap_generator_checkpoint(ckpt_file, os.path.join(input_dir, os.path.basename(ckpt_file).replace('.ckpt', '.pth')))
| 41 | +        print(f"Unwrapped {ckpt_file} to {os.path.join(input_dir, os.path.basename(ckpt_file).replace('.ckpt', '.pth'))}")
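With the command-line interface removed, `unwrap.py` now sweeps a hard-coded checkpoint directory and writes one generator-only `.pth` per Lightning `.ckpt`. A hedged sketch of loading such an unwrapped file back into the `MelUNet` generator follows; the config path, checkpoint filename, and the assumption that `config['model']['params']` matches the trained model are illustrative, not part of this commit.

```python
import torch
import yaml

from models import UNet  # MelUNet generator, as imported in train.py

# Hypothetical paths: point these at your config.yaml and unwrapped weights.
with open("config.yaml") as f:
    config = yaml.safe_load(f)

generator = UNet.MelUNet(**config['model']['params'])
state_dict = torch.load("checkpoints/mel-unet/00100000.pth", map_location="cpu")
generator.load_state_dict(state_dict)  # keys were saved directly from the generator in unwrap.py
generator.eval()
```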