yongyizang committed
Commit 2b2771c · 1 Parent(s): fccca85

update scripts

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -21,6 +21,34 @@ The repository is organized to separate concerns, making it easy to extend and m
  - `discriminator/` <- Discriminator architectures
  - `generator/` <- Reusable generator components
 
+ ## Run Inference On The Pretrained Models
+
+ Download the pretrained models from https://huggingface.co/yongyizang/MSRChallengeBaseline, then run `inference.py` to evaluate them.
+
+ ```bash
+ python inference.py --config config.yaml --checkpoint path/to/your/checkpoint.ckpt --input_dir path/to/your/input/directory --output_dir path/to/your/output/directory
+ ```
+
+ Every `*.flac` file in the `input_dir` will be processed and saved in the `output_dir`.
+
+ ## Evaluation Script
+
+ The evaluation script is provided in `calculate_metrics.py`.
+
+ ```bash
+ python calculate_metrics.py {file list}
+ ```
+
+ The evaluation script expects a file list with each line in the format `{target path}|{output path}`. Results are printed to the console; you can append `> output.txt` to redirect them to a file.
+
+ We recommend modifying this script to fit your needs.
+
+ ---
+
+ For a comprehensive list of arguments, please check each individual script.
+
+ ---
+
  ## 🚀 Getting Started
 
  ### 1. Setup
@@ -43,8 +71,6 @@ Key sections to update:
 
  - `data.train_dataset.root_directory`: Path to your training data.
  - `data.train_dataset.file_list`: Path to a `.txt` file listing your training samples.
- - `data.val_dataset.root_directory`: Path to your validation data.
- - `data.val_dataset.file_list`: Path to a `.txt` file listing your validation samples.
  - `model`: Choose the generator model and its parameters.
  - `discriminators`: Add and configure one or more discriminators.
  - `trainer`: Set training parameters like `max_steps`, `devices` (GPU IDs), and `precision`.
@@ -57,8 +83,6 @@ Launch the training process using the `train.py` script and your configuration f
  python train.py --config config.yaml
  ```
 
- Logs, checkpoints, and audio samples will be saved in the `lightning_logs/` directory.
-
  ### 4. Unwrap Generator Weights
 
  After training, you may want to use the generator model for inference without the rest of the Lightning module. The `unwrap.py` script extracts the generator's `state_dict` from a checkpoint file.
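The file list consumed by `calculate_metrics.py` pairs each ground-truth file with the corresponding model output, one pair per line separated by `|`. A minimal example (these paths are hypothetical, not from the repository):

```
/data/test/Voc/song_001/target.flac|/outputs/Voc/song_001.flac
/data/test/Voc/song_002/target.flac|/outputs/Voc/song_002.flac
```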
calculate_metrics.py ADDED
@@ -0,0 +1,182 @@
+ import os
+ import soundfile as sf
+ import torch
+ from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
+ import argparse
+ import numpy as np
+ import warnings
+ from scipy.linalg import sqrtm
+ from tqdm import tqdm
+
+ warnings.filterwarnings("ignore")
+
+ try:
+     from transformers import ClapModel, ClapProcessor
+ except ImportError:
+     print("Error: The 'transformers' library is not installed.")
+     print("Please install it to run FAD-CLAP calculations:")
+     print("pip install torch transformers")
+     exit(1)
+
+
+ def load_audio(file_path, sr=48000):
+     try:
+         wav, samplerate = sf.read(file_path)
+         if samplerate != sr:
+             pass
+         if wav.ndim > 1:
+             wav = wav.T
+         else:
+             wav = wav[np.newaxis, :]
+         return torch.from_numpy(wav).float()
+     except Exception:
+         return None
+
+ def get_clap_embeddings(file_paths, model, processor, device, batch_size=16):
+     model.to(device)
+     all_embeddings = []
+
+     for i in tqdm(range(0, len(file_paths), batch_size), desc=" Calculating embeddings", ncols=100, leave=False):
+         batch_paths = file_paths[i:i+batch_size]
+         audio_batch = []
+         for path in batch_paths:
+             try:
+                 wav, sr = sf.read(path)
+                 if wav.ndim == 2 and wav.shape[1] == 2:
+                     audio_batch.append(wav[:, 0]) # Left channel
+                     audio_batch.append(wav[:, 1]) # Right channel
+                 elif wav.ndim == 1:
+                     audio_batch.append(wav)
+                 else:
+                     continue
+             except Exception:
+                 continue
+
+         if not audio_batch:
+             continue
+
+         try:
+             inputs = processor(audios=audio_batch, sampling_rate=48000, return_tensors="pt", padding=True)
+             inputs = {key: val.to(device) for key, val in inputs.items()}
+
+             with torch.no_grad():
+                 audio_features = model.get_audio_features(**inputs)
+
+             all_embeddings.append(audio_features.cpu().numpy())
+         except Exception:
+             continue
+
+     if not all_embeddings:
+         return np.array([])
+
+     return np.concatenate(all_embeddings, axis=0)
+
+ def calculate_frechet_distance(embeddings1, embeddings2):
+     if embeddings1.shape[0] < 2 or embeddings2.shape[0] < 2:
+         return None
+
+     mu1, mu2 = np.mean(embeddings1, axis=0), np.mean(embeddings2, axis=0)
+     sigma1, sigma2 = np.cov(embeddings1, rowvar=False), np.cov(embeddings2, rowvar=False)
+
+     ssdiff = np.sum((mu1 - mu2)**2.0)
+
+     try:
+         covmean, _ = sqrtm(sigma1.dot(sigma2), disp=False)
+     except Exception:
+         return None
+
+     if np.iscomplexobj(covmean):
+         covmean = covmean.real
+
+     fad_score = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
+     return fad_score
+
+ def main():
+     parser = argparse.ArgumentParser(description="Calculate SI-SNR and FAD-CLAP for audio pairs listed in a text file.")
+     parser.add_argument("file_list", type=str, help="Path to a text file with the format: target_path|output_path")
+     parser.add_argument("--batch_size", type=int, default=16, help="Batch size for FAD-CLAP embedding calculation.")
+     args = parser.parse_args()
+
+     if not os.path.exists(args.file_list):
+         print(f"Error: Input file not found at {args.file_list}")
+         return
+
+     sisnr_calculator = ScaleInvariantSignalNoiseRatio()
+     all_target_paths = []
+     all_output_paths = []
+
+     print("--- Calculating SI-SNR for each pair ---")
+     with open(args.file_list, 'r') as f:
+         for line in f:
+             line = line.strip()
+             if not line or '|' not in line:
+                 continue
+
+             try:
+                 target_path, output_path = [p.strip() for p in line.split('|')]
+
+                 if not os.path.exists(target_path) or not os.path.exists(output_path):
+                     print(f"Skipping line, file not found: {line}")
+                     continue
+
+                 target_wav = load_audio(target_path)
+                 output_wav = load_audio(output_path)
+
+                 if target_wav is None or output_wav is None:
+                     continue
+                 if target_wav.shape[0] != output_wav.shape[0]:
+                     continue
+
+                 min_len = min(target_wav.shape[-1], output_wav.shape[-1])
+                 target_wav = target_wav[..., :min_len]
+                 output_wav = output_wav[..., :min_len]
+
+                 if target_wav.shape[-1] == 0:
+                     continue
+
+                 sisnr_val = sisnr_calculator(output_wav, target_wav)
+                 print(f"{target_path}|{output_path}|{sisnr_val.item():.4f}")
+
+                 all_target_paths.append(target_path)
+                 all_output_paths.append(output_path)
+
+             except Exception:
+                 continue
+
+     print("\n--- Calculating FAD-CLAP for all target vs. all output files ---")
+     if not all_target_paths:
+         print("No valid file pairs found to calculate FAD-CLAP.")
+         return
+
+     try:
+         print("Loading CLAP model...")
+         clap_model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
+         clap_processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")
+         clap_model.eval()
+         print("CLAP model loaded successfully.")
+     except Exception as e:
+         print(f"Fatal Error: Could not load CLAP model. Please check internet connection. Error: {e}")
+         return
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {device}")
+
+     print("\nCalculating embeddings for all target files...")
+     target_embeddings = get_clap_embeddings(all_target_paths, clap_model, clap_processor, device, args.batch_size)
+
+     print("Calculating embeddings for all output files...")
+     output_embeddings = get_clap_embeddings(all_output_paths, clap_model, clap_processor, device, args.batch_size)
+
+     if target_embeddings.size > 0 and output_embeddings.size > 0:
+         print("Calculating Frechet Audio Distance (FAD)...")
+         fad_score = calculate_frechet_distance(target_embeddings, output_embeddings)
+         if fad_score is not None:
+             print(f"\nOverall FAD-CLAP Score: {fad_score:.4f}")
+         else:
+             print("\nCould not calculate FAD-CLAP score.")
+     else:
+         print("\nCould not calculate FAD-CLAP due to issues with embedding generation.")
+
+ if __name__ == "__main__":
+     main()
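The FAD computation above is the closed-form Fréchet distance between two Gaussians fitted to the embedding sets. A minimal self-contained sketch of the same formula on toy data (not part of the repository; the embedding arrays are synthetic stand-ins):

```python
import numpy as np
from scipy.linalg import sqrtm

rng = np.random.default_rng(0)
emb_a = rng.normal(0.0, 1.0, size=(64, 8))   # stand-in for target CLAP embeddings
emb_b = rng.normal(0.5, 1.0, size=(64, 8))   # stand-in for output CLAP embeddings

mu_a, mu_b = emb_a.mean(axis=0), emb_b.mean(axis=0)
sig_a = np.cov(emb_a, rowvar=False)
sig_b = np.cov(emb_b, rowvar=False)

# FAD = |mu_a - mu_b|^2 + Tr(sig_a + sig_b - 2 * (sig_a @ sig_b)^(1/2))
covmean = sqrtm(sig_a @ sig_b)
if np.iscomplexobj(covmean):
    covmean = covmean.real  # sqrtm can return tiny imaginary parts numerically
fad = np.sum((mu_a - mu_b) ** 2) + np.trace(sig_a + sig_b - 2.0 * covmean)
print(f"toy FAD: {fad:.4f}")
```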
config.yaml CHANGED
@@ -9,32 +9,24 @@ model:
     num_heads: 4
     window_size: 2048
     hop_size: 512
-    sample_rate: 44100
+    sample_rate: 48000
 
 discriminators:
   - name: "MultiFrequencyDiscriminator"
     params:
       nch: 1
       window_sizes: [2048, 1024, 512]
-      sample_rate: 44100
+      sample_rate: 48000
       norm: True
- # you can add more discriminators here
 
 data:
-   sample_rate: 44100
+   sample_rate: 48000
   clip_duration: 3.0
   train_dataset:
     target_stem: "Voc"
-     root_directory: "/path/to/your/training/data"
-     file_list: "/path/to/your/train_split.txt"
+     root_directory: "/path/to/your/training/data/dir"
     apply_augmentation: True
     snr_range: [0.0, 10.0]
-   val_dataset:
-     target_stem: "Voc"
-     root_directory: "/path/to/your/validation/data"
-     file_list: "/path/to/your/val_split.txt"
-     apply_augmentation: True
-     snr_range: [5.0, 5.0] # Fixed SNR for validation
   dataloader_params:
     batch_size: 4
     num_workers: 8
@@ -56,15 +48,14 @@ losses:
   lambda_feat: 2.0
   lambda_gan: 1.0
   reconstruction_loss:
-     sample_rate: 44100
+     sample_rate: 48000
     n_fft: [1024, 2048, 512]
     hop_length: [256, 512, 128]
     n_mels: [80, 160, 40]
 
 trainer:
   max_steps: 1000000
-   val_check_interval: 5000
   log_every_n_steps: 100
-   devices: [0] # List of GPU IDs to use
-   precision: bf16-mixed
-   log_media_every_n_steps: 5000
+   checkpoint_save_interval: 10000
+   devices: [0]
+   precision: bf16-mixed
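Since `sample_rate` now appears in several places in the config (the model params, the discriminator params, `data`, and the reconstruction loss), a quick sanity check can catch a mismatch before a long run. A small sketch that does not assume any particular nesting (it just walks the parsed YAML recursively; `config.yaml` is your own config path):

```python
import yaml

def find_key(node, key, found=None):
    """Recursively collect every value stored under `key` in a nested config."""
    found = [] if found is None else found
    if isinstance(node, dict):
        for k, v in node.items():
            if k == key:
                found.append(v)
            find_key(v, key, found)
    elif isinstance(node, list):
        for item in node:
            find_key(item, key, found)
    return found

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

rates = set(find_key(cfg, "sample_rate"))
assert len(rates) == 1, f"Inconsistent sample rates in config: {rates}"
```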
data/augment.py CHANGED
@@ -1,5 +1,5 @@
 import numpy as np
- from eq_utils import apply_random_eq
+ from data.eq_utils import apply_random_eq
 from pedalboard import Pedalboard, Resample, Compressor, Distortion, Reverb, Limiter, MP3Compressor
 
 def fix_length_to_duration(target: np.ndarray, duration: float) -> np.ndarray:
data/dataset.py CHANGED
@@ -8,7 +8,7 @@ import json
 from typing import List, Optional, Dict, Union, Tuple, Any
 from torch.utils.data import Dataset, Sampler
 from tqdm import tqdm
- from augment import StemAugmentation, MixtureAugmentation
+ from data.augment import StemAugmentation, MixtureAugmentation
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
@@ -84,8 +84,7 @@ class RawStems(Dataset):
         self,
         target_stem: str,
         root_directory: Union[str, Path],
-         file_list: Union[str, Path],
-         sr: int = 44100,
+         sr: int = 48000,
         clip_duration: float = 3.0,
         snr_range: Tuple[float, float] = (0.0, 10.0),
         apply_augmentation: bool = True,
@@ -97,18 +96,27 @@ class RawStems(Dataset):
         self.snr_range = snr_range
         self.apply_augmentation = apply_augmentation
         self.rms_threshold = rms_threshold
-
-         self.folders = []
-         with open(file_list, 'r') as f:
-             for line in f:
-                 folder = self.root_directory / Path(line.strip())
-                 if folder.exists(): self.folders.append(folder)
-                 else: logger.warning(f"Folder does not exist: {folder}")
 
         target_stem_parts = target_stem.split("_")
         self.target_stem_1 = target_stem_parts[0].strip()
         self.target_stem_2 = target_stem_parts[1].strip() if len(target_stem_parts) > 1 else None
 
+         logger.info(f"Scanning '{self.root_directory}' for songs containing stem '{target_stem}'...")
+         self.folders = []
+         for song_dir in self.root_directory.iterdir():
+             if song_dir.is_dir():
+                 target_path = song_dir / self.target_stem_1
+                 if self.target_stem_2:
+                     target_path /= self.target_stem_2
+
+                 if target_path.exists() and target_path.is_dir():
+                     self.folders.append(song_dir)
+
+         if not self.folders:
+             raise FileNotFoundError(f"No subdirectories in '{self.root_directory}' were found containing the stem path '{target_stem}'. "
+                                     f"Please check your directory structure.")
+         logger.info(f"Found {len(self.folders)} song folders.")
+
         self.audio_files = self._index_audio_files()
         if not self.audio_files: raise ValueError("No audio files found.")
 
@@ -150,12 +158,10 @@ class RawStems(Dataset):
                activity_masks[path_str] = np.array([False] * len(rms_values))
                continue
 
-             # Efficiently check if the average RMS in a sliding window is above the threshold
            is_loud = rms_values > self.rms_threshold
            sum_loud = np.convolve(is_loud, np.ones(window_size), 'valid')
-             avg_loud_enough = sum_loud / window_size > 0.8 # At least 80% of seconds must be loud
+             avg_loud_enough = sum_loud / window_size > 0.8
 
-             # Pad the mask to match the original length of rms_values
            mask = np.zeros(len(rms_values), dtype=bool)
            mask[:len(avg_loud_enough)] = avg_loud_enough
            activity_masks[path_str] = mask
@@ -171,13 +177,12 @@ class RawStems(Dataset):
        for file_path in file_paths:
            path_str = str(file_path.relative_to(self.root_directory))
            mask = self.activity_masks.get(path_str)
-             if mask is None: return [] # This file has no mask, combination is invalid
+             if mask is None: return []
            masks_to_intersect.append(mask)
            min_len = min(min_len, len(mask))
 
        if not masks_to_intersect: return []
 
-         # Truncate all masks to the minimum length and intersect
        final_mask = np.ones(min_len, dtype=bool)
        for mask in masks_to_intersect:
            final_mask &= mask[:min_len]
@@ -204,7 +209,7 @@ class RawStems(Dataset):
                    if not is_target:
                        song_dict["others"].append(p)
                except ValueError:
-                     continue # Should not happen if p is from folder.rglob
+                     continue
 
            if song_dict["target_stems"] and song_dict["others"]:
                indexed_songs.append(song_dict)
@@ -226,12 +231,11 @@ class RawStems(Dataset):
                start_second = random.choice(valid_starts)
                offset = start_second + random.uniform(0, 1.0 - (self.clip_duration % 1.0 or 1.0))
 
-                 # --- Audio Loading and Mixing ---
                target_mix = sum(load_audio(p, offset, self.clip_duration, self.sr) for p in selected_targets) / num_targets
                other_mix = sum(load_audio(p, offset, self.clip_duration, self.sr) for p in selected_others) / num_others
 
                if not contains_audio_signal(target_mix) or not contains_audio_signal(other_mix):
-                     continue # Should be rare now, but as a safeguard
+                     continue
 
                target_clean = target_mix.copy()
                target_augmented = self.stem_augmentation.apply(target_mix, self.sr) if self.apply_augmentation else target_mix
@@ -243,16 +247,28 @@ class RawStems(Dataset):
 
                mixture_augmented = self.mixture_augmentation.apply(mixture, self.sr) if self.apply_augmentation else mixture
 
-                 # --- Normalization and final prep ---
                max_val = np.max(np.abs(mixture_augmented)) + 1e-8
                mixture_final = mixture_augmented / max_val
                target_final = target_clean / max_val
 
                rescale = np.random.uniform(*DEFAULT_GAIN_RANGE)
+
+                 mixture = np.nan_to_num(mixture_final * rescale)
+                 target = np.nan_to_num(target_final * rescale)
+
+                 target_length = int(self.clip_duration * self.sr)
+                 if target.shape[1] != target_length:
+                     target = np.pad(target, (0, target_length - target.shape[1]), mode='constant')
+                 else:
+                     target = target[:, :target_length]
+                 if mixture.shape[1] != target_length:
+                     mixture = np.pad(mixture, (0, target_length - mixture.shape[1]), mode='constant')
+                 else:
+                     mixture = mixture[:, :target_length]
 
                return {
-                     "mixture": np.nan_to_num(mixture_final * rescale),
-                     "target": np.nan_to_num(target_final * rescale)
+                     "mixture": np.nan_to_num(mixture),
+                     "target": np.nan_to_num(target)
                }
 
        return self.__getitem__(random.randint(0, len(self.audio_files) - 1))
@@ -275,35 +291,4 @@ class InfiniteSampler(Sampler):
        while True:
            if self.pointer >= self.dataset_size: self.reset()
            yield self.indexes[self.pointer]
-             self.pointer += 1
-
- if __name__ == "__main__":
-     root = "/lan/ifc/downloaded_datasets/cambridge-mt/sorted_files"
-     dataset = RawStems(
-         target_stem="Voc",
-         root_directory=root,
-         file_list="/home/yongyizang/music_source_restoration/configs/data_split/Voc_train.txt",
-         sr=44100,
-         clip_duration=10.0,
-         apply_augmentation=True,
-         rms_threshold=-30.0
-     )
-
-     sampler = InfiniteSampler(dataset)
-     iterator = iter(sampler)
-
-     output_dir = Path("./msr_test_set/Voc/")
-     output_dir.mkdir(parents=True, exist_ok=True)
-     logger.info(f"Output directory: {output_dir}")
-
-     for i in tqdm(range(10), desc="Generating test samples"):
-         index = next(iterator)
-         sample = dataset[index]
-
-         mixture_path = output_dir / f"mixture_{i}.wav"
-         target_path = output_dir / f"target_{i}.wav"
-
-         sf.write(mixture_path, sample["mixture"].T, dataset.sr)
-         sf.write(target_path, sample["target"].T, dataset.sr)
-
-     print("Test complete.")
+             self.pointer += 1
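With `file_list` removed, `RawStems` now discovers songs by scanning `root_directory` itself: every immediate subdirectory that contains the target-stem path is kept as a song folder. Under that logic the expected layout is roughly the following (song and file names here are illustrative, not from the repository):

```
root_directory/
├── SongA/
│   ├── Voc/    <- target_stem="Voc"; its audio files become target stems
│   └── Gtr/    <- all other stems become "others"
└── SongB/
    ├── Voc/
    └── Bass/
```

One thing to watch in the new length-fixing code: `np.pad` with a flat `(0, n)` tuple pads every axis of a 2-D `(channels, samples)` array, and a clip longer than `target_length` would produce a negative pad width. A small sketch of a per-axis variant that avoids both (an assumption about the intended behaviour, not the committed code):

```python
import numpy as np

def fix_length(x: np.ndarray, target_length: int) -> np.ndarray:
    """Pad the time axis with zeros, or truncate, to exactly target_length samples."""
    if x.shape[1] < target_length:
        # Pad only the second (time) axis, never the channel axis.
        return np.pad(x, ((0, 0), (0, target_length - x.shape[1])), mode='constant')
    return x[:, :target_length]
```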
evaluation/README.md DELETED
@@ -1,31 +0,0 @@
- # Evaluation Module
-
- This directory contains classes for evaluating model performance during validation. All metrics inherit from a base `Metric` class for a consistent interface.
-
- ## Files
-
- ### `metrics.py`
-
- #### `SI_SNR` (Scale-Invariant Signal-to-Noise Ratio)
-
- A common metric for audio source separation that measures the quality of the restored signal relative to the original target. It is invariant to the overall scaling of the estimated signal.
-
- - `update(pred, target)`: Updates the running statistics with a new batch of predicted and target audio tensors.
- - `compute()`: Calculates the mean and standard deviation of the SI-SNR scores accumulated since the last reset.
- - `reset()`: Clears the accumulated statistics.
-
- #### `FAD_CLAP` (Fréchet Audio Distance using CLAP)
-
- Measures the Fréchet distance between the distributions of embeddings from the generated audio and the ground truth audio. It uses a pre-trained CLAP (Contrastive Language-Audio Pretraining) model to generate these embeddings, providing a perceptually relevant measure of audio quality and similarity.
-
- **Note:** This metric requires the `laion-clap` library. If not installed, it will fall back to using random embeddings, which is not meaningful for evaluation.
-
- - `update(pred, target)`: Extracts CLAP embeddings from the predicted and target audio tensors and stores them.
- - `compute()`: Calculates the FAD score between the collected sets of embeddings.
- - `reset()`: Clears the stored embeddings.
-
- **`__init__` Arguments:**
-
- - `embedding_dim` (`int`): The dimensionality of the embeddings. Should match the CLAP model. Default: `512`.
- - `model_name` (`str`): The name of the CLAP model architecture to use. Default: `'HTSAT-base'`.
- - `ckpt_path` (`Optional[str]`): Optional path to a specific CLAP model checkpoint. If `None`, it uses the default pre-trained weights.
evaluation/__init__.py DELETED
File without changes
evaluation/metrics.py DELETED
@@ -1,183 +0,0 @@
- import torch
- import torch.nn as nn
- import logging
- from typing import Dict, List, Optional, Any, Tuple
- from abc import ABC, abstractmethod
-
- try:
-     import laion_clap
- except ImportError:
-     raise ImportError(
-         "The `laion_clap` package is required for the FAD metric. "
-         "Please install it with: pip install laion-clap"
-     )
-
- class Metric(nn.Module, ABC):
-     def __init__(self):
-         super().__init__()
-         self.register_buffer("dummy_buffer", torch.empty(0))
-
-     @property
-     def device(self) -> torch.device:
-         return self.dummy_buffer.device
-
-     @abstractmethod
-     def reset(self):
-         raise NotImplementedError
-
-     @abstractmethod
-     def update(self, *args: Any, **kwargs: Any):
-         raise NotImplementedError
-
-     @abstractmethod
-     def compute(self) -> Dict[str, float]:
-         raise NotImplementedError
-
- class SI_SNR(Metric):
-     def __init__(self, eps: float = 1e-8):
-         super().__init__()
-         self.eps = eps
-         self.reset()
-
-     def reset(self):
-         self.register_buffer("sum_scores", torch.tensor(0.0, dtype=torch.float64))
-         self.register_buffer("sum_sq_scores", torch.tensor(0.0, dtype=torch.float64))
-         self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
-
-     def update(self, pred: torch.Tensor, target: torch.Tensor):
-         score = self._compute_si_snr(pred, target).detach()
-         self.sum_scores += torch.sum(score)
-         self.sum_sq_scores += torch.sum(score.pow(2))
-         self.count += score.numel()
-
-     def compute(self) -> Dict[str, float]:
-         if self.count.item() == 0:
-             return {'mean': 0.0, 'std': 0.0, 'count': 0}
-
-         total_count = self.count.item()
-         mean_val = (self.sum_scores / self.count).item()
-         var = (self.sum_sq_scores / self.count) - (self.sum_scores / self.count).pow(2)
-         std_val = torch.sqrt(var).item() if var > 0 and total_count > 1 else 0.0
-
-         return {'mean': mean_val, 'std': std_val, 'count': int(total_count)}
-
-     def _compute_si_snr(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-         pred = pred.view(-1, pred.shape[-1])
-         target = target.view(-1, target.shape[-1])
-         pred_zm = pred - pred.mean(dim=-1, keepdim=True)
-         target_zm = target - target.mean(dim=-1, keepdim=True)
-         alpha = (pred_zm * target_zm).sum(dim=-1, keepdim=True) / \
-             (target_zm.pow(2).sum(dim=-1, keepdim=True) + self.eps)
-         target_scaled = alpha * target_zm
-         noise = pred_zm - target_scaled
-         si_snr_val = (target_scaled.pow(2).sum(dim=-1)) / \
-             (noise.pow(2).sum(dim=-1) + self.eps)
-         return 10 * torch.log10(si_snr_val + self.eps)
-
- class FAD_CLAP(Metric):
-     def __init__(self, embedding_dim: int = 512, model_name: str = 'HTSAT-base', ckpt_path: Optional[str] = None):
-         super().__init__()
-         self.embedding_dim = embedding_dim
-         self.clap_model = self._load_clap_model(model_name, ckpt_path)
-
-         self.pred_embeddings: List[torch.Tensor] = []
-         self.target_embeddings: List[torch.Tensor] = []
-         self.reset()
-
-     def _load_clap_model(self, model_name: str, ckpt_path: Optional[str]) -> Optional[nn.Module]:
-         if laion_clap is None:
-             logging.warning("`laion_clap` is not installed. FAD will use random embeddings.")
-             return None
-         try:
-             logging.info(f"Loading CLAP model '{model_name}' for FAD metric...")
-             model = laion_clap.CLAP_Module(enable_fusion=False, amodel=model_name)
-             model.load_ckpt(ckpt_path)
-             model.eval()
-             logging.info("CLAP model loaded successfully.")
-             return model
-         except Exception as e:
-             logging.warning(f"Failed to load CLAP model due to an error: {e}. FAD will use random embeddings.")
-             return None
-
-     def to(self, *args, **kwargs):
-         super().to(*args, **kwargs)
-         return self
-
-     def reset(self):
-         self.pred_embeddings.clear()
-         self.target_embeddings.clear()
-
-     def update(self, pred: torch.Tensor, target: torch.Tensor):
-         self.pred_embeddings.append(self._extract_embedding(pred).cpu())
-         self.target_embeddings.append(self._extract_embedding(target).cpu())
-
-     def compute(self) -> Dict[str, float]:
-         if not self.pred_embeddings or not self.target_embeddings:
-             return {'fad': float('inf'), 'count': 0}
-
-         pred_emb_all = torch.cat(self.pred_embeddings, dim=0).to(self.device)
-         target_emb_all = torch.cat(self.target_embeddings, dim=0).to(self.device)
-
-         if pred_emb_all.shape[0] < 2 or target_emb_all.shape[0] < 2:
-             logging.warning(f"FAD requires at least 2 samples per set, but got {pred_emb_all.shape[0]} and {target_emb_all.shape[0]}.")
-             return {'fad': float('inf'), 'count': pred_emb_all.shape[0]}
-
-         mu_pred, sigma_pred = self._get_mu_and_sigma(pred_emb_all)
-         mu_target, sigma_target = self._get_mu_and_sigma(target_emb_all)
-         fad_score = self._frechet_distance(mu_pred, sigma_pred, mu_target, sigma_target)
-         return {'fad': fad_score.item(), 'count': len(pred_emb_all)}
-
-     @torch.no_grad()
-     def _extract_embedding(self, audio: torch.Tensor) -> torch.Tensor:
-         if self.clap_model is None:
-             return torch.randn(audio.shape[0], self.embedding_dim, device=audio.device)
-
-         self.clap_model.to(audio.device)
-
-         audio_dict = {'waveform': audio, 'sample_rate': 48000}
-         return self.clap_model.get_audio_embedding_from_data(x=audio_dict, use_tensor=True)
-
-     def _get_mu_and_sigma(self, embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-         mu = embeddings.mean(dim=0)
-         sigma = torch.cov(embeddings.T)
-         return mu, sigma
-
-     def _frechet_distance(self, mu1, sigma1, mu2, sigma2) -> torch.Tensor:
-         diff = mu1 - mu2
-         mean_dist_sq = diff.dot(diff)
-         try:
-             offset = torch.eye(sigma1.shape[0], device=self.device, dtype=sigma1.dtype) * 1e-6
-             cov_sqrt = torch.linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)).real
-         except RuntimeError:
-             logging.warning("Matrix square root failed. Using diagonal approximation for FAD.")
-             cov_sqrt = torch.sqrt(torch.diag(sigma1) * torch.diag(sigma2))
-         trace_term = torch.trace(sigma1) + torch.trace(sigma2) - 2 * torch.trace(cov_sqrt)
-         return mean_dist_sq + trace_term
-
- if __name__ == '__main__':
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     print("Initializing FAD metric...")
-     fad_metric = FAD_CLAP()
-     fad_metric.to(device)
-
-     sample_rate = 48000
-     dummy_pred_audio_batch1 = torch.randn(4, sample_rate * 2, device=device)
-     dummy_target_audio_batch1 = torch.randn(4, sample_rate * 2, device=device)
-
-     dummy_pred_audio_batch2 = torch.randn(4, sample_rate * 2, device=device)
-     dummy_target_audio_batch2 = torch.randn(4, sample_rate * 2, device=device)
-
-     print("\nUpdating metric with batch 1...")
-     fad_metric.update(pred=dummy_pred_audio_batch1, target=dummy_target_audio_batch1)
-
-     print("Updating metric with batch 2...")
-     fad_metric.update(pred=dummy_pred_audio_batch2, target=dummy_target_audio_batch2)
-
-     print("\nComputing final FAD score...")
-     final_fad_score = fad_metric.compute()
-
-     print(f"Final FAD results: {final_fad_score}")
-
-     fad_metric.reset()
-     print("\nMetric has been reset.")
-     print(f"State after reset: pred_embeddings={fad_metric.pred_embeddings}, target_embeddings={fad_metric.target_embeddings}")
inference.py ADDED
@@ -0,0 +1,113 @@
+ import argparse
+ import yaml
+ from pathlib import Path
+ from typing import Dict, Any
+
+ import torch
+ import torch.nn as nn
+ import soundfile as sf
+ import numpy as np
+ from tqdm import tqdm
+
+ from models import MelRNN, MelRoFormer, UNet
+
+
+ def load_generator(config: Dict[str, Any], checkpoint_path: str, device: str = 'cuda') -> nn.Module:
+     """Initialize and load the generator model from unwrapped checkpoint."""
+     model_cfg = config['model']
+
+     # Initialize generator based on config
+     if model_cfg['name'] == 'MelRNN':
+         generator = MelRNN.MelRNN(**model_cfg['params'])
+     elif model_cfg['name'] == 'MelRoFormer':
+         generator = MelRoFormer.MelRoFormer(**model_cfg['params'])
+     elif model_cfg['name'] == 'MelUNet':
+         generator = UNet.MelUNet(**model_cfg['params'])
+     else:
+         raise ValueError(f"Unknown model name: {model_cfg['name']}")
+
+     # Load unwrapped generator weights
+     state_dict = torch.load(checkpoint_path, map_location=device)
+     generator.load_state_dict(state_dict)
+
+     generator = generator.to(device)
+     generator.eval()
+
+     return generator
+
+
+ def process_audio(audio: np.ndarray, generator: nn.Module, device: str = 'cuda') -> np.ndarray:
+     """Process a single audio array through the generator."""
+     # Convert to tensor: (channels, samples) -> (1, channels, samples)
+     if audio.ndim == 1:
+         audio = audio[np.newaxis, :] # Add channel dimension for mono
+
+     audio_tensor = torch.from_numpy(audio).float().to(device)
+
+     # Run inference
+     with torch.no_grad():
+         output_tensor = generator(audio_tensor)
+
+     # Convert back to numpy: (1, channels, samples) -> (channels, samples)
+     output_audio = output_tensor.cpu().numpy()
+
+     return output_audio
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Run inference on audio files using trained generator")
+     parser.add_argument("--config", type=str, required=True, help="Path to config.yaml")
+     parser.add_argument("--checkpoint", type=str, required=True, help="Path to unwrapped generator weights (.pth)")
+     parser.add_argument("--input_dir", type=str, required=True, help="Directory containing input .flac files")
+     parser.add_argument("--output_dir", type=str, required=True, help="Directory to save processed audio")
+     parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (cuda/cpu)")
+     args = parser.parse_args()
+
+     # Load config
+     with open(args.config, 'r') as f:
+         config = yaml.safe_load(f)
+
+     # Setup paths
+     input_dir = Path(args.input_dir)
+     output_dir = Path(args.output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # Get all .flac files
+     audio_files = sorted(input_dir.glob("*.flac"))
+
+     if len(audio_files) == 0:
+         print(f"No .flac files found in {input_dir}")
+         return
+
+     print(f"Found {len(audio_files)} audio files")
+
+     # Load model
+     print(f"Loading generator from {args.checkpoint}...")
+     generator = load_generator(config, args.checkpoint, device=args.device)
+     print("Model loaded successfully")
+
+     # Process each file
+     for audio_file in tqdm(audio_files, desc="Processing audio files"):
+         # Load audio
+         audio, sr = sf.read(audio_file)
+
+         # Transpose if needed: soundfile loads as (samples, channels)
+         if audio.ndim == 2:
+             audio = audio.T # Convert to (channels, samples)
+
+         # Process through generator
+         output_audio = process_audio(audio, generator, device=args.device)
+
+         # Transpose back for saving: (channels, samples) -> (samples, channels)
+         if output_audio.ndim == 2:
+             output_audio = output_audio.T
+
+         # Save with same filename
+         output_path = output_dir / audio_file.name
+         sf.write(output_path, output_audio, sr)
+
+     print(f"\nProcessing complete! Output saved to {output_dir}")
+
+
+ if __name__ == '__main__':
+     main()
models/MelRNN.py CHANGED
@@ -25,8 +25,8 @@ class MelRNN(nn.Module):
 
    def forward(self, x):
        original_length = x.shape[1]
-         identity = self.fourier.stft(x)
-         x = self.band.split(identity) # (B, C, T, F)
+         x = self.fourier.stft(x)
+         x = self.band.split(x) # (B, C, T, F)
 
        x = rearrange(x, 'b c t f -> b t f c')
        b, t, f, c = x.shape
@@ -40,13 +40,12 @@
        x = rearrange(x, '(b f) t c -> b t f c', f=f)
 
        x = rearrange(x, 'b t f c -> b c t f')
-         mask = self.band.unsplit(x)
-         identity = identity * mask
-         x = self.fourier.istft(identity, original_length)
+         x = self.band.unsplit(x)
+         x = self.fourier.istft(x.contiguous(), original_length)
        return x
 
 if __name__ == "__main__":
-     model = MelRNN(hidden_channels=128, num_layers=12, num_groups=4, window_size=2048, hop_size=512, sample_rate=48000)
+     model = MelRNN(hidden_channels=128, num_layers=9, num_groups=8, window_size=2048, hop_size=512, sample_rate=48000)
 
    x = torch.randn(4, 96000)
 
models/UNet.py CHANGED
@@ -23,8 +23,8 @@ class MelUNet(nn.Module):
 
    def forward(self, x):
        original_length = x.shape[1]
-         identity = self.fourier.stft(x)
-         x = self.band.split(identity) # (B, C, T, F)
+         x = self.fourier.stft(x)
+         x = self.band.split(x) # (B, C, T, F)
 
        residuals = []
        for i in range(self.num_layers):
@@ -37,14 +37,13 @@
            if i < self.num_layers - 1:
                x = x + residual
 
-         mask = self.band.unsplit(x)
-         identity = identity * mask
-         x = self.fourier.istft(identity, original_length)
+         x = self.band.unsplit(x)
+         x = self.fourier.istft(x.contiguous(), original_length)
        return x
 
 if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     model = MelUNet(hidden_channels=128, num_layers=2, upsampling_factor=2, window_size=2048, hop_size=512, sample_rate=48000)
+     model = MelUNet(hidden_channels=32, num_layers=4, upsampling_factor=2, window_size=2048, hop_size=512, sample_rate=48000)
 
    x = torch.randn(4, 96000)
    x = x.to(device)
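Both `MelRNN.forward` and `MelUNet.forward` change the same way in this commit: the network output is no longer interpreted as a multiplicative mask applied to the input STFT (`identity * mask`); the band-unsplit output is sent straight to `istft`. A toy-tensor sketch of the two formulations (names and shapes are illustrative only, not the repository's code):

```python
import torch

spec = torch.randn(2, 1025, 256)      # stand-in for the input STFT representation
net_out = torch.randn(2, 1025, 256)   # stand-in for the network's band-unsplit output

masked_estimate = spec * net_out      # old: output acts as a mask on the input spectrogram
direct_estimate = net_out             # new: output is the estimated spectrogram itself
```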
modules/generator/ConvNeXt2DBlock.py CHANGED
@@ -35,7 +35,7 @@ class ConvNeXt2DBlock(nn.Module):
        self.dwconv = nn.ConvTranspose2d(dim, dim, kernel_size=kernel_size, stride=stride, padding=self.padding)
        self.residual_conv = nn.ConvTranspose2d(dim, output_dim, kernel_size=kernel_size, stride=stride, padding=self.padding)
        self.norm = RMSNorm(dim)
-         self.n_hidden = int(8 * dim / 3)
+         self.n_hidden = int(4 * dim / 3)
        self.pwconv1 = nn.Linear(dim, self.n_hidden * 2)
        self.pwconv2 = nn.Linear(self.n_hidden, output_dim)
 
train.py CHANGED
@@ -2,7 +2,7 @@ import argparse
 import yaml
 from pathlib import Path
 from typing import Dict, Any, List
-
+ from einops import rearrange
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
@@ -10,16 +10,10 @@ import pytorch_lightning as pl
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
 
- import numpy as np
- import soundfile as sf
- import matplotlib.pyplot as plt
- import librosa
-
 from data.dataset import RawStems, InfiniteSampler
 from models import MelRNN, MelRoFormer, UNet
 from losses.gan_loss import GeneratorLoss, DiscriminatorLoss, FeatureMatchingLoss
 from losses.reconstruction_loss import MultiMelSpecReconstructionLoss
- from evaluation.metrics import SI_SNR, FAD_CLAP
 
 from modules.discriminator.MultiPeriodDiscriminator import MultiPeriodDiscriminator
 from modules.discriminator.MultiScaleDiscriminator import MultiScaleDiscriminator
@@ -55,12 +49,11 @@ class CombinedDiscriminator(nn.Module):
        return all_scores, all_fmaps
 
 class MusicRestorationDataModule(pl.LightningDataModule):
-     """Handles data loading for training and validation."""
+     """Handles data loading for training."""
    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.config = config
        self.train_dataset = None
-         self.val_dataset = None
 
    def setup(self, stage: str | None = None):
        common_params = {
@@ -68,7 +61,6 @@ class MusicRestorationDataModule(pl.LightningDataModule):
            "clip_duration": self.config['clip_duration'],
        }
        self.train_dataset = RawStems(**self.config['train_dataset'], **common_params)
-         self.val_dataset = RawStems(**self.config['val_dataset'], **common_params)
 
    def train_dataloader(self):
        sampler = InfiniteSampler(self.train_dataset)
@@ -77,13 +69,6 @@ class MusicRestorationDataModule(pl.LightningDataModule):
            sampler=sampler,
            **self.config['dataloader_params']
        )
-
-     def val_dataloader(self):
-         return DataLoader(
-             self.val_dataset,
-             shuffle=False,
-             **self.config['dataloader_params']
-         )
 
 class MusicRestorationModule(pl.LightningModule):
    """
@@ -108,18 +93,12 @@ class MusicRestorationModule(pl.LightningModule):
        self.loss_feat = FeatureMatchingLoss()
        self.loss_recon = MultiMelSpecReconstructionLoss(**loss_cfg['reconstruction_loss'])
 
-         # 4. Validation Metrics
-         self.val_si_snr = SI_SNR()
-         # Note: FAD_CLAP requires `laion_clap` to be installed.
-         # It will gracefully fall back to random embeddings if not found.
-         self.val_fad = FAD_CLAP()
-
    def _init_generator(self):
        model_cfg = self.hparams.model
        if model_cfg['name'] == 'MelRNN':
-             return MelRNN(**model_cfg['params'])
+             return MelRNN.MelRNN(**model_cfg['params'])
        elif model_cfg['name'] == 'MelRoFormer':
-             return MelRoFormer(**model_cfg['params'])
+             return MelRoFormer.MelRoFormer(**model_cfg['params'])
        elif model_cfg['name'] == 'MelUNet':
            return UNet.MelUNet(**model_cfg['params'])
        else:
@@ -133,12 +112,16 @@ class MusicRestorationModule(pl.LightningModule):
 
        target = batch['target']
        mixture = batch['mixture']
+
+         # reshape both from (b, c, t) to ((b, c) t)
+         target = rearrange(target, 'b c t -> (b c) t')
+         mixture = rearrange(mixture, 'b c t -> (b c) t')
 
        # --- Train Discriminator ---
        generated = self(mixture)
 
-         real_scores, _ = self.discriminator(target)
-         fake_scores, _ = self.discriminator(generated.detach())
+         real_scores, _ = self.discriminator(target.unsqueeze(1))
+         fake_scores, _ = self.discriminator(generated.detach().unsqueeze(1))
 
        d_loss, _, _ = self.loss_disc_adv(real_scores, fake_scores)
 
@@ -148,8 +131,8 @@ class MusicRestorationModule(pl.LightningModule):
        self.log('train/d_loss', d_loss, prog_bar=True)
 
        # --- Train Generator ---
-         real_scores, real_fmaps = self.discriminator(target)
-         fake_scores, fake_fmaps = self.discriminator(generated)
+         real_scores, real_fmaps = self.discriminator(target.unsqueeze(1))
+         fake_scores, fake_fmaps = self.discriminator(generated.unsqueeze(1))
 
        # Reconstruction Loss
        loss_recon = self.loss_recon(generated, target)
@@ -180,52 +163,6 @@ class MusicRestorationModule(pl.LightningModule):
        sch_g, sch_d = self.lr_schedulers()
        if sch_g: sch_g.step()
        if sch_d: sch_d.step()
-
-     def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int):
-         target = batch['target']
-         mixture = batch['mixture']
-
-         generated = self(mixture)
-
-         loss_recon = self.loss_recon(generated, target)
-         self.log('val/loss_recon', loss_recon, on_step=False, on_epoch=True, sync_dist=True)
-
-         self.val_si_snr.update(generated.detach(), target.detach())
-         self.val_fad.update(generated.detach(), target.detach())
-
-         # Log one audio example and spectrogram per validation epoch
-         if batch_idx == 0:
-             self._log_media(mixture[0], target[0], generated[0])
-
-     def on_validation_epoch_end(self):
-         si_snr_results = self.val_si_snr.compute()
-         fad_results = self.val_fad.compute()
-
-         self.log('val/si_snr', si_snr_results['mean'], sync_dist=True)
-         self.log('val/fad', fad_results['fad'], sync_dist=True)
-
-         self.val_si_snr.reset()
-         self.val_fad.reset()
-
-     def _log_media(self, mixture: torch.Tensor, target: torch.Tensor, generated: torch.Tensor):
-         sr = self.hparams.data['sample_rate']
-
-         # Log audio
-         self.logger.experiment.add_audio("val_audio/mixture", mixture.mean(0).cpu(), self.global_step, sample_rate=sr)
-         self.logger.experiment.add_audio("val_audio/target", target.mean(0).cpu(), self.global_step, sample_rate=sr)
-         self.logger.experiment.add_audio("val_audio/generated", generated.mean(0).cpu(), self.global_step, sample_rate=sr)
-
-         # Log spectrograms
-         fig, axes = plt.subplots(3, 1, figsize=(10, 12))
-         for i, (title, audio) in enumerate([("Mixture", mixture), ("Target", target), ("Generated", generated)]):
-             audio_np = audio.mean(0).cpu().numpy().astype(np.float32)
-             mel_spec = librosa.feature.melspectrogram(y=audio_np, sr=sr)
-             mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
-             librosa.display.specshow(mel_spec_db, sr=sr, x_axis='time', y_axis='mel', ax=axes[i])
-             axes[i].set_title(title)
-         plt.tight_layout()
-         self.logger.experiment.add_figure("val_spectrograms", fig, self.global_step)
-         plt.close(fig)
 
    def configure_optimizers(self):
        # Generator Optimizer
@@ -248,7 +185,7 @@
 
 def main():
    parser = argparse.ArgumentParser(description="Train a Music Source Restoration Model")
-     parser.add_argument("--config", type=str, required=True, help="Path to the config.yaml file.")
+     parser.add_argument("--config", type=str, required=True, help="Path to the config file.")
    args = parser.parse_args()
 
    with open(args.config, 'r') as f:
@@ -258,24 +195,26 @@ def main():
 
    data_module = MusicRestorationDataModule(config['data'])
    model_module = MusicRestorationModule(config)
-
-     save_dir = Path("lightning_logs") / config['project_name'] / config['exp_name']
+
+     exp_name = f"{config['model']['name']}"
+     exp_name = exp_name.replace(" ", "_")
+     save_dir = Path(config['trainer']['save_dir']) / config['project_name'] / exp_name
 
    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=save_dir / "checkpoints",
-         filename="{step:08d}-{val/si_snr:.2f}",
-         every_n_train_steps=config['trainer']['val_check_interval'],
-         save_top_k=-1, # Save all checkpoints
+         filename="{step:08d}",
+         every_n_train_steps=config['trainer']['checkpoint_save_interval'],
+         save_top_k=-1,
        auto_insert_metric_name=False
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')
 
    # Logger
    logger = TensorBoardLogger(
-         save_dir="lightning_logs",
+         save_dir=config['trainer']['save_dir'],
        name=config['project_name'],
-         version=config['exp_name']
+         version=exp_name
    )
 
    # Trainer
@@ -283,11 +222,10 @@
        logger=logger,
        callbacks=[checkpoint_callback, lr_monitor],
        max_steps=config['trainer']['max_steps'],
-         val_check_interval=config['trainer']['val_check_interval'],
        log_every_n_steps=config['trainer']['log_every_n_steps'],
        devices=config['trainer']['devices'],
        precision=config['trainer']['precision'],
-         accelerator="gpu",
+         accelerator="gpu"
    )
 
    trainer.fit(model_module, datamodule=data_module)
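The new `training_step` folds the channel axis into the batch before the generator and then restores a channel dimension of 1 for the discriminators. The shape bookkeeping, as a standalone sketch (batch size, channel count, and length here are arbitrary):

```python
import torch
from einops import rearrange

b, c, t = 4, 2, 96000
target = torch.randn(b, c, t)

# (b, c, t) -> (b*c, t): each channel is treated as an independent mono example.
flat = rearrange(target, 'b c t -> (b c) t')
print(flat.shape)      # torch.Size([8, 96000])

# Discriminators expect an explicit channel axis, so unsqueeze back to (b*c, 1, t).
disc_in = flat.unsqueeze(1)
print(disc_in.shape)   # torch.Size([8, 1, 96000])
```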
unwrap.py CHANGED
@@ -1,5 +1,5 @@
 import torch
- import argparse
+ import os, glob
 from collections import OrderedDict
 from pathlib import Path
 
@@ -33,22 +33,9 @@ def unwrap_generator_checkpoint(ckpt_path: str, output_path: str) -> None:
    torch.save(generator_state_dict, output_path)
 
 if __name__ == '__main__':
-     parser = argparse.ArgumentParser(
-         description="Unwrap a generator model from a PyTorch Lightning checkpoint."
-     )
-     parser.add_argument(
-         '--ckpt',
-         type=str,
-         required=True,
-         help="Path to the input PyTorch Lightning checkpoint file (.ckpt)."
-     )
-     parser.add_argument(
-         '--out',
-         type=str,
-         required=True,
-         help="Path to save the unwrapped generator weights (.pth)."
-     )
-
-     args = parser.parse_args()
-
-     unwrap_generator_checkpoint(args.ckpt, args.out)
+     input_dir = "/root/autodl-tmp/checkpoints/mel-unet"
+     # find all .ckpt files in the input directory
+     ckpt_files = glob.glob(os.path.join(input_dir, '*.ckpt'))
+     for ckpt_file in ckpt_files:
+         unwrap_generator_checkpoint(ckpt_file, os.path.join(input_dir, os.path.basename(ckpt_file).replace('.ckpt', '.pth')))
+         print(f"Unwrapped {ckpt_file} to {os.path.join(input_dir, os.path.basename(ckpt_file).replace('.ckpt', '.pth'))}")
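Since this commit removes the `--ckpt`/`--out` CLI, the `__main__` block now targets a hard-coded directory. To unwrap your own checkpoints you can either edit `input_dir` or call the function directly; a minimal sketch (the paths are placeholders):

```python
from unwrap import unwrap_generator_checkpoint

# Hypothetical paths: point these at your own Lightning checkpoint and output file.
unwrap_generator_checkpoint(
    "path/to/your/checkpoint.ckpt",
    "path/to/your/generator.pth",
)
```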