JacobLinCool committed on
Commit 7cdcef9 · unverified · 1 Parent(s): 44d0a59

Delete exp/baseline

exp/baseline/__init__.py DELETED
File without changes
exp/baseline/data.py DELETED
@@ -1,128 +0,0 @@
import torch
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm
from .utils import extract_context


class BeatTrackingDataset(Dataset):
    def __init__(
        self, hf_dataset, target_type="beats", sample_rate=16000, hop_length=160
    ):
        """
        Args:
            hf_dataset: HuggingFace dataset object
            target_type (str): "beats" or "downbeats". Determines which labels are treated as positive.
        """
        self.sr = sample_rate
        self.hop_length = hop_length
        self.target_type = target_type

        # Context window size in samples (7 frames = 70ms at 100fps)
        self.context_frames = 7
        self.context_samples = (self.context_frames * 2 + 1) * hop_length + max(
            [368, 736, 1488]
        )  # extra for FFT window

        # Cache audio arrays in memory for fast access
        self.audio_cache = []
        self.indices = []
        self._prepare_indices(hf_dataset)

    def _prepare_indices(self, hf_dataset):
        """
        Prepares balanced indices and caches audio.
        Paper Section 4.5: Uses "Fuzzier" training examples (neighbors weighted less).
        """
        print(f"Preparing dataset indices for target: {self.target_type}...")

        for i, item in tqdm(
            enumerate(hf_dataset), total=len(hf_dataset), desc="Building indices"
        ):
            # Cache audio array (convert to numpy if tensor)
            audio = item["audio"]["array"]
            if hasattr(audio, "numpy"):
                audio = audio.numpy()
            self.audio_cache.append(audio)

            # Calculate total frames available in audio
            audio_len = len(audio)
            n_frames = int(audio_len / self.hop_length)

            # Select ground truth based on target_type
            if self.target_type == "downbeats":
                # Only downbeats are positives
                gt_times = item["downbeats"]
            else:
                # All beats are positives (downbeats are also beats)
                gt_times = item["beats"]

            # Convert to list if tensor
            if hasattr(gt_times, "tolist"):
                gt_times = gt_times.tolist()

            gt_frames = set([int(t * self.sr / self.hop_length) for t in gt_times])

            # --- Positive Examples (with Fuzziness) ---
            # "define a single frame before and after each annotated onset to be additional positive examples"
            pos_frames = set()
            for bf in gt_frames:
                if 0 <= bf < n_frames:
                    self.indices.append((i, bf, 1.0))  # Center frame (sharp onset)
                    pos_frames.add(bf)

                # Neighbors weighted at 0.25
                if 0 <= bf - 1 < n_frames:
                    self.indices.append((i, bf - 1, 0.25))
                    pos_frames.add(bf - 1)
                if 0 <= bf + 1 < n_frames:
                    self.indices.append((i, bf + 1, 0.25))
                    pos_frames.add(bf + 1)

            # --- Negative Examples ---
            # Paper uses "all others as negative", but we balance 2:1 for stable SGD.
            num_pos = len(pos_frames)
            num_neg = num_pos * 2

            count = 0
            attempts = 0
            while count < num_neg and attempts < num_neg * 5:
                f = np.random.randint(0, n_frames)
                if f not in pos_frames:
                    self.indices.append((i, f, 0.0))
                    count += 1
                attempts += 1

        print(
            f"Dataset ready. {len(self.indices)} samples, {len(self.audio_cache)} tracks cached."
        )

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        track_idx, frame_idx, label = self.indices[idx]

        # Fast lookup from cache
        audio = self.audio_cache[track_idx]
        audio_len = len(audio)

        # Calculate sample range for context window
        center_sample = frame_idx * self.hop_length
        half_context = self.context_samples // 2
        start = center_sample - half_context
        end = center_sample + half_context

        # Handle padding if needed
        pad_left = max(0, -start)
        pad_right = max(0, end - audio_len)
        start = max(0, start)
        end = min(audio_len, end)

        # Extract audio chunk
        chunk = audio[start:end]
        if pad_left > 0 or pad_right > 0:
            chunk = np.pad(chunk, (pad_left, pad_right), mode="constant")

        waveform = torch.tensor(chunk, dtype=torch.float32)
        return waveform, torch.tensor([label], dtype=torch.float32)
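To make the fuzzy labeling above concrete: at the defaults used here (sample_rate=16000, hop_length=160, i.e. 100 frames per second), one annotated beat expands into three weighted targets. A minimal standalone sketch (the beat time is a made-up value, not from any dataset):

sr, hop = 16000, 160  # 100 frames per second
beat_time = 2.437  # hypothetical annotation at 2.437 s

center = int(beat_time * sr / hop)  # frame 243
targets = [(center, 1.0), (center - 1, 0.25), (center + 1, 0.25)]
print(targets)  # [(243, 1.0), (242, 0.25), (244, 0.25)]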
exp/baseline/eval.py DELETED
@@ -1,326 +0,0 @@
import torch
import numpy as np
from tqdm import tqdm
from scipy.signal import find_peaks
import argparse
import os

from .model import ODCNN
from .utils import MultiViewSpectrogram
from ..data.load import ds
from ..data.eval import evaluate_all, format_results


def get_activation_function(model, waveform, device):
    """
    Computes the probability curve over time.
    """
    processor = MultiViewSpectrogram().to(device)
    waveform = waveform.unsqueeze(0).to(device)

    with torch.no_grad():
        spec = processor(waveform)

        # Normalize
        mean = spec.mean(dim=(2, 3), keepdim=True)
        std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
        spec = (spec - mean) / std

        # Batchify with sliding window
        spec = torch.nn.functional.pad(spec, (7, 7))  # Pad time
        windows = spec.unfold(3, 15, 1)  # (1, 3, 80, Time, 15)
        windows = windows.permute(3, 0, 1, 2, 4).squeeze(1)  # (Time, 3, 80, 15)

        # Inference
        activations = []
        batch_size = 512
        for i in range(0, len(windows), batch_size):
            batch = windows[i : i + batch_size]
            out = model(batch)
            activations.append(out.cpu().numpy())

    return np.concatenate(activations).flatten()


def pick_peaks(activations, hop_length=160, sr=16000):
    """
    Smooth with a Hamming window and report local maxima.
    """
    # Smoothing
    window = np.hamming(5)
    window /= window.sum()
    smoothed = np.convolve(activations, window, mode="same")

    # Peak picking
    peaks, _ = find_peaks(smoothed, height=0.5, distance=5)

    timestamps = peaks * hop_length / sr
    return timestamps.tolist()


def visualize_track(
    audio: np.ndarray,
    sr: int,
    pred_beats: list[float],
    pred_downbeats: list[float],
    gt_beats: list[float],
    gt_downbeats: list[float],
    output_dir: str,
    track_idx: int,
    time_range: tuple[float, float] | None = None,
):
    """
    Create and save visualizations for a single track.
    """
    from ..data.viz import plot_waveform_with_beats, save_figure

    os.makedirs(output_dir, exist_ok=True)

    # Full waveform plot
    fig = plot_waveform_with_beats(
        audio,
        sr,
        pred_beats,
        gt_beats,
        pred_downbeats,
        gt_downbeats,
        title=f"Track {track_idx}: Beat Comparison",
        time_range=time_range,
    )
    save_figure(fig, os.path.join(output_dir, f"track_{track_idx:03d}.png"))


def synthesize_audio(
    audio: np.ndarray,
    sr: int,
    pred_beats: list[float],
    pred_downbeats: list[float],
    gt_beats: list[float],
    gt_downbeats: list[float],
    output_dir: str,
    track_idx: int,
    click_volume: float = 0.5,
):
    """
    Create and save audio files with click tracks for a single track.
    """
    from ..data.audio import create_comparison_audio, save_audio

    os.makedirs(output_dir, exist_ok=True)

    # Create comparison audio
    audio_pred, audio_gt, audio_both = create_comparison_audio(
        audio,
        pred_beats,
        pred_downbeats,
        gt_beats,
        gt_downbeats,
        sr=sr,
        click_volume=click_volume,
    )

    # Save audio files
    save_audio(
        audio_pred, os.path.join(output_dir, f"track_{track_idx:03d}_pred.wav"), sr
    )
    save_audio(audio_gt, os.path.join(output_dir, f"track_{track_idx:03d}_gt.wav"), sr)
    save_audio(
        audio_both, os.path.join(output_dir, f"track_{track_idx:03d}_both.wav"), sr
    )


def main():
    parser = argparse.ArgumentParser(
        description="Evaluate beat tracking models with visualization and audio synthesis"
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default="outputs/baseline",
        help="Base directory containing trained models (with 'beats' and 'downbeats' subdirs)",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=20,
        help="Number of samples to evaluate",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/eval",
        help="Directory to save visualizations and audio",
    )
    parser.add_argument(
        "--visualize",
        action="store_true",
        help="Generate visualization plots for each track",
    )
    parser.add_argument(
        "--synthesize",
        action="store_true",
        help="Generate audio files with click tracks",
    )
    parser.add_argument(
        "--viz-tracks",
        type=int,
        default=5,
        help="Number of tracks to visualize/synthesize (default: 5)",
    )
    parser.add_argument(
        "--time-range",
        type=float,
        nargs=2,
        default=None,
        metavar=("START", "END"),
        help="Time range for visualization in seconds (default: full track)",
    )
    parser.add_argument(
        "--click-volume",
        type=float,
        default=0.5,
        help="Volume of click sounds relative to audio (0.0 to 1.0)",
    )
    parser.add_argument(
        "--summary-plot",
        action="store_true",
        help="Generate summary evaluation plot",
    )
    args = parser.parse_args()

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Load BOTH models using from_pretrained
    beat_model = None
    downbeat_model = None

    has_beats = False
    has_downbeats = False

    beats_dir = os.path.join(args.model_dir, "beats")
    downbeats_dir = os.path.join(args.model_dir, "downbeats")

    if os.path.exists(os.path.join(beats_dir, "model.safetensors")) or os.path.exists(
        os.path.join(beats_dir, "pytorch_model.bin")
    ):
        beat_model = ODCNN.from_pretrained(beats_dir).to(DEVICE)
        beat_model.eval()
        has_beats = True
        print(f"Loaded beat model from {beats_dir}")
    else:
        print(f"Warning: No beat model found in {beats_dir}")

    if os.path.exists(
        os.path.join(downbeats_dir, "model.safetensors")
    ) or os.path.exists(os.path.join(downbeats_dir, "pytorch_model.bin")):
        downbeat_model = ODCNN.from_pretrained(downbeats_dir).to(DEVICE)
        downbeat_model.eval()
        has_downbeats = True
        print(f"Loaded downbeat model from {downbeats_dir}")
    else:
        print(f"Warning: No downbeat model found in {downbeats_dir}")

    if not has_beats and not has_downbeats:
        print("No models found. Please run training first.")
        return

    predictions = []
    ground_truths = []
    audio_data = []  # Store audio for visualization/synthesis

    # Eval on the specified number of tracks (note: drawn from the train split)
    test_set = ds["train"].select(range(args.num_samples))

    print("Running evaluation...")
    for i, item in enumerate(tqdm(test_set)):
        waveform = torch.tensor(item["audio"]["array"], dtype=torch.float32)
        waveform_device = waveform.to(DEVICE)

        pred_entry = {"beats": [], "downbeats": []}

        # 1. Predict beats
        if has_beats:
            act_b = get_activation_function(beat_model, waveform_device, DEVICE)
            pred_entry["beats"] = pick_peaks(act_b)

        # 2. Predict downbeats
        if has_downbeats:
            act_d = get_activation_function(downbeat_model, waveform_device, DEVICE)
            pred_entry["downbeats"] = pick_peaks(act_d)

        predictions.append(pred_entry)
        ground_truths.append({"beats": item["beats"], "downbeats": item["downbeats"]})

        # Store audio for later visualization/synthesis
        if args.visualize or args.synthesize:
            if i < args.viz_tracks:
                audio_data.append(
                    {
                        "audio": waveform.numpy(),
                        "sr": item["audio"]["sampling_rate"],
                        "pred": pred_entry,
                        "gt": ground_truths[-1],
                    }
                )

    # Run evaluation
    results = evaluate_all(predictions, ground_truths)
    print(format_results(results))

    # Create output directory
    if args.visualize or args.synthesize or args.summary_plot:
        os.makedirs(args.output_dir, exist_ok=True)

    # Generate visualizations
    if args.visualize:
        print(f"\nGenerating visualizations for {len(audio_data)} tracks...")
        viz_dir = os.path.join(args.output_dir, "plots")
        for i, data in enumerate(tqdm(audio_data, desc="Visualizing")):
            time_range = tuple(args.time_range) if args.time_range else None
            visualize_track(
                data["audio"],
                data["sr"],
                data["pred"]["beats"],
                data["pred"]["downbeats"],
                data["gt"]["beats"],
                data["gt"]["downbeats"],
                viz_dir,
                i,
                time_range=time_range,
            )
        print(f"Saved visualizations to {viz_dir}")

    # Generate audio with clicks
    if args.synthesize:
        print(f"\nSynthesizing audio for {len(audio_data)} tracks...")
        audio_dir = os.path.join(args.output_dir, "audio")
        for i, data in enumerate(tqdm(audio_data, desc="Synthesizing")):
            synthesize_audio(
                data["audio"],
                data["sr"],
                data["pred"]["beats"],
                data["pred"]["downbeats"],
                data["gt"]["beats"],
                data["gt"]["downbeats"],
                audio_dir,
                i,
                click_volume=args.click_volume,
            )
        print(f"Saved audio files to {audio_dir}")
        print("  *_pred.wav - Original audio with predicted beat clicks")
        print("  *_gt.wav   - Original audio with ground truth beat clicks")
        print("  *_both.wav - Original audio with both predicted and GT clicks")

    # Generate summary plot
    if args.summary_plot:
        from ..data.viz import plot_evaluation_summary, save_figure

        print("\nGenerating summary plot...")
        fig = plot_evaluation_summary(results, title="Beat Tracking Evaluation Summary")
        summary_path = os.path.join(args.output_dir, "evaluation_summary.png")
        save_figure(fig, summary_path)
        print(f"Saved summary plot to {summary_path}")


if __name__ == "__main__":
    main()
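As a quick sanity check of the post-processing above, a self-contained sketch (synthetic activations, not repository code) that applies the same Hamming smoothing and find_peaks call as pick_peaks; soft bumps every 50 frames come back as beats every 0.5 s:

import numpy as np
from scipy.signal import find_peaks

hop, sr = 160, 16000  # 100 frames per second

# Synthetic activation curve: a soft 3-frame bump every 50 frames
act = np.zeros(500)
for k in range(50, 500, 50):
    act[k - 1 : k + 2] = [0.6, 1.0, 0.6]

# Same smoothing and peak picking as pick_peaks() above
window = np.hamming(5)
window /= window.sum()
smoothed = np.convolve(act, window, mode="same")
peaks, _ = find_peaks(smoothed, height=0.5, distance=5)
print(peaks * hop / sr)  # [0.5 1.  1.5 ... 4.5]

Note that the bumps must span a few frames: the normalized 5-tap Hamming window scales an isolated single-frame spike down to roughly 0.45, just under the 0.5 height threshold.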
exp/baseline/model.py DELETED
@@ -1,62 +0,0 @@
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class ODCNN(nn.Module, PyTorchModelHubMixin):
    def __init__(self, dropout_rate=0.5):
        super().__init__()

        # Input: 3 channels, 80 mel bands
        # Conv 1: 3x7 (freq x time) filters -> 10 maps
        self.conv1 = nn.Conv2d(3, 10, kernel_size=(3, 7))
        self.relu1 = nn.ReLU()  # ReLU improvement
        self.pool1 = nn.MaxPool2d(kernel_size=(3, 1), stride=(3, 1))

        # Conv 2: 3x3 filters -> 20 maps
        self.conv2 = nn.Conv2d(10, 20, kernel_size=(3, 3))
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=(3, 1), stride=(3, 1))

        # Flatten size calculation based on architecture
        # (20 feature maps * 8 freq bands * 7 time frames)
        self.flatten_size = 20 * 8 * 7

        # Dropout on FC inputs
        self.dropout = nn.Dropout(p=dropout_rate)

        # 256 hidden units
        self.fc1 = nn.Linear(self.flatten_size, 256)
        self.relu_fc = nn.ReLU()

        # Output unit
        self.fc2 = nn.Linear(256, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        x = x.view(x.size(0), -1)

        x = self.dropout(x)
        x = self.fc1(x)
        x = self.relu_fc(x)

        x = self.dropout(x)
        x = self.fc2(x)
        x = self.sigmoid(x)

        return x


if __name__ == "__main__":
    from torchinfo import summary

    model = ODCNN()
    summary(model, (1, 3, 80, 15))
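The flatten_size = 20 * 8 * 7 above can be verified with a quick shape walk-through on the (3, 80, 15) input; a standalone sketch, not part of the file:

import torch

x = torch.randn(1, 3, 80, 15)
x = torch.nn.Conv2d(3, 10, kernel_size=(3, 7))(x)  # -> (1, 10, 78, 9)
x = torch.nn.MaxPool2d((3, 1), stride=(3, 1))(x)  # -> (1, 10, 26, 9)
x = torch.nn.Conv2d(10, 20, kernel_size=(3, 3))(x)  # -> (1, 20, 24, 7)
x = torch.nn.MaxPool2d((3, 1), stride=(3, 1))(x)  # -> (1, 20, 8, 7)
print(x.flatten(1).shape)  # torch.Size([1, 1120]) == 20 * 8 * 7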
exp/baseline/train.py DELETED
@@ -1,183 +0,0 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import argparse
import os

from .model import ODCNN
from .data import BeatTrackingDataset
from .utils import MultiViewSpectrogram
from ..data.load import ds


def train(target_type: str, output_dir: str):
    # Note: Paper uses SGD with momentum, dropout, and ReLU
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    BATCH_SIZE = 512
    EPOCHS = 50
    LR = 0.05
    MOMENTUM = 0.9
    NUM_WORKERS = 4

    print(f"--- Training model for target: {target_type} ---")
    print(f"Output directory: {output_dir}")

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # TensorBoard writer
    writer = SummaryWriter(log_dir=os.path.join(output_dir, "logs"))

    # Data - use existing train/test splits
    train_dataset = BeatTrackingDataset(ds["train"], target_type=target_type)
    val_dataset = BeatTrackingDataset(ds["test"], target_type=target_type)

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        prefetch_factor=4,
        persistent_workers=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        prefetch_factor=4,
        persistent_workers=True,
    )

    print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    # Model
    model = ODCNN(dropout_rate=0.5).to(DEVICE)

    # GPU spectrogram preprocessor
    preprocessor = MultiViewSpectrogram(sample_rate=16000, hop_length=160).to(DEVICE)

    # Optimizer
    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    criterion = nn.BCELoss()  # Binary cross-entropy

    best_val_loss = float("inf")
    global_step = 0

    for epoch in range(EPOCHS):
        # Training
        model.train()
        total_train_loss = 0
        for waveform, y in tqdm(
            train_loader,
            desc=f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} Train",
            leave=False,
        ):
            waveform, y = waveform.to(DEVICE), y.to(DEVICE)

            # Compute spectrogram on GPU
            with torch.no_grad():
                spec = preprocessor(waveform)  # (B, 3, 80, T)
                # Normalize
                mean = spec.mean(dim=(2, 3), keepdim=True)
                std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
                spec = (spec - mean) / std
                # Extract center context (T should be ~15 frames)
                x = spec[:, :, :, 7:22]  # center 15 frames

            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            global_step += 1

            # Log batch loss
            writer.add_scalar("train/batch_loss", loss.item(), global_step)

        avg_train_loss = total_train_loss / len(train_loader)

        # Validation
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for waveform, y in tqdm(
                val_loader,
                desc=f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} Val",
                leave=False,
            ):
                waveform, y = waveform.to(DEVICE), y.to(DEVICE)

                # Compute spectrogram on GPU
                spec = preprocessor(waveform)  # (B, 3, 80, T)
                # Normalize
                mean = spec.mean(dim=(2, 3), keepdim=True)
                std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
                spec = (spec - mean) / std
                # Extract center context
                x = spec[:, :, :, 7:22]

                output = model(x)
                loss = criterion(output, y)
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_loader)

        # Log epoch metrics
        writer.add_scalar("train/epoch_loss", avg_train_loss, epoch)
        writer.add_scalar("val/loss", avg_val_loss, epoch)
        writer.add_scalar("train/learning_rate", scheduler.get_last_lr()[0], epoch)

        # Step the scheduler
        scheduler.step()

        print(
            f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} - "
            f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}"
        )

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            model.save_pretrained(output_dir)
            print(f"  -> Saved best model (val_loss: {best_val_loss:.4f})")

    writer.close()

    # Save final model
    final_dir = os.path.join(output_dir, "final")
    model.save_pretrained(final_dir)
    print(f"Saved final model to {final_dir}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--target",
        type=str,
        choices=["beats", "downbeats"],
        default=None,
        help="Train a model for 'beats' or 'downbeats'. If not specified, trains both.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/baseline",
        help="Directory to save model and logs",
    )
    args = parser.parse_args()

    # Determine which targets to train
    targets = [args.target] if args.target else ["beats", "downbeats"]

    for target in targets:
        output_dir = os.path.join(args.output_dir, target)
        train(target, output_dir)
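Unlike the fixed learning rate the paper's setup suggests, this script anneals LR = 0.05 with a cosine schedule over the 50 epochs. A standalone sketch of the closed form that CosineAnnealingLR applies (with eta_min = 0):

import math

lr0, T = 0.05, 50
lrs = [lr0 * (1 + math.cos(math.pi * t / T)) / 2 for t in range(T)]
print(lrs[0], lrs[25], round(lrs[49], 6))  # 0.05, 0.025, ~5e-05

Since the module uses package-relative imports, it is presumably launched as python -m exp.baseline.train --target beats (or with no --target to train both models).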
exp/baseline/utils.py DELETED
@@ -1,53 +0,0 @@
import torch
import torch.nn as nn
import torchaudio.transforms as T
import numpy as np


class MultiViewSpectrogram(nn.Module):
    def __init__(self, sample_rate=16000, n_mels=80, hop_length=160):
        super().__init__()
        # Window sizes: 23ms, 46ms, 93ms at 16 kHz
        self.win_lengths = [368, 736, 1488]
        self.transforms = nn.ModuleList()

        for win_len in self.win_lengths:
            n_fft = 2 ** int(np.ceil(np.log2(win_len)))
            mel = T.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=n_fft,
                win_length=win_len,
                hop_length=hop_length,
                f_min=27.5,
                f_max=16000.0,
                n_mels=n_mels,
                power=1.0,
                center=True,
            )
            self.transforms.append(mel)

    def forward(self, waveform):
        specs = []
        for transform in self.transforms:
            # Scale magnitudes logarithmically
            s = transform(waveform)
            s = torch.log(s + 1e-9)
            specs.append(s)
        return torch.stack(specs, dim=1)


def extract_context(spec, center_frame, context=7):
    # Context of +/- 70ms (7 frames)
    channels, n_mels, total_time = spec.shape
    start = center_frame - context
    end = center_frame + context + 1

    pad_left = max(0, -start)
    pad_right = max(0, end - total_time)

    if pad_left > 0 or pad_right > 0:
        spec = torch.nn.functional.pad(spec, (pad_left, pad_right))
        start += pad_left
        end += pad_left

    return spec[:, :, start:end]
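For reference, each window length is padded up to the next power of two for the FFT; a standalone sketch of the sizes this yields:

import numpy as np

for win_len in [368, 736, 1488]:
    n_fft = 2 ** int(np.ceil(np.log2(win_len)))
    print(f"{win_len / 16:.0f} ms window -> n_fft = {n_fft}")
# 23 ms window -> n_fft = 512
# 46 ms window -> n_fft = 1024
# 93 ms window -> n_fft = 2048

One caveat worth noting: f_max=16000.0 is above the 8 kHz Nyquist frequency for 16 kHz audio, so the highest mel filters come out empty and the effective analysis range ends at 8 kHz; torchaudio accepts this but may warn about empty filterbanks.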