import argparse
import os

import numpy as np
import torch
from scipy.signal import find_peaks
from tqdm import tqdm

from .model import ResNet
from ..baseline1.utils import MultiViewSpectrogram
from ..data.load import ds
from ..data.eval import evaluate_all, format_results


def get_activation_function(model, waveform, device):
    """
    Compute the model's frame-wise activation (probability) curve over time.
    """
    processor = MultiViewSpectrogram().to(device)
    waveform = waveform.unsqueeze(0).to(device)

    with torch.no_grad():
        spec = processor(waveform)

        # Standardize each channel over the frequency and time axes.
        mean = spec.mean(dim=(2, 3), keepdim=True)
        std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
        spec = (spec - mean) / std

        # Pad 50 frames on each side so every original frame can sit at the
        # center of a 101-frame window, then unfold along the time axis into
        # one overlapping window per frame.
        spec = torch.nn.functional.pad(spec, (50, 50))
        windows = spec.unfold(3, 101, 1)
        windows = windows.permute(3, 0, 1, 2, 4).squeeze(1)

        # Run the model over the windows in batches to bound GPU memory use.
        activations = []
        batch_size = 128
        for i in range(0, len(windows), batch_size):
            batch = windows[i : i + batch_size]
            out = model(batch)
            activations.append(out.cpu().numpy())

    return np.concatenate(activations).flatten()
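
# A minimal usage sketch (illustrative, not part of the pipeline): given a
# trained window classifier `model` and a mono 16 kHz waveform tensor `wav`,
#
#     act = get_activation_function(model, wav, "cuda")
#
# returns a 1-D numpy array with one activation value per spectrogram frame,
# matching the hop_length/sr defaults assumed by pick_peaks below.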


def pick_peaks(activations, hop_length=160, sr=16000):
    """
    Smooth the activation curve with a Hamming window and report the
    timestamps (in seconds) of its local maxima.
    """
    # Normalize the 5-tap Hamming window so smoothing preserves scale.
    window = np.hamming(5)
    window /= window.sum()
    smoothed = np.convolve(activations, window, mode="same")

    # Keep maxima above 0.5 that are at least 5 frames apart.
    peaks, _ = find_peaks(smoothed, height=0.5, distance=5)

    # Convert frame indices to seconds.
    timestamps = peaks * hop_length / sr
    return timestamps.tolist()
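
# Frame-to-time arithmetic, for reference: with hop_length=160 at sr=16000,
# each frame spans 10 ms, so e.g. a peak at frame index 250 maps to
# 250 * 160 / 16000 = 2.5 s. The distance=5 constraint therefore enforces
# at least 50 ms between consecutive detections.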


def visualize_track(
    audio: np.ndarray,
    sr: int,
    pred_beats: list[float],
    pred_downbeats: list[float],
    gt_beats: list[float],
    gt_downbeats: list[float],
    output_dir: str,
    track_idx: int,
    time_range: tuple[float, float] | None = None,
):
    """
    Create and save visualizations for a single track.
    """
    from ..data.viz import plot_waveform_with_beats, save_figure

    os.makedirs(output_dir, exist_ok=True)

    fig = plot_waveform_with_beats(
        audio,
        sr,
        pred_beats,
        gt_beats,
        pred_downbeats,
        gt_downbeats,
        title=f"Track {track_idx}: Beat Comparison",
        time_range=time_range,
    )
    save_figure(fig, os.path.join(output_dir, f"track_{track_idx:03d}.png"))


def synthesize_audio(
    audio: np.ndarray,
    sr: int,
    pred_beats: list[float],
    pred_downbeats: list[float],
    gt_beats: list[float],
    gt_downbeats: list[float],
    output_dir: str,
    track_idx: int,
    click_volume: float = 0.5,
):
    """
    Create and save audio files with click tracks for a single track.
    """
    from ..data.audio import create_comparison_audio, save_audio

    os.makedirs(output_dir, exist_ok=True)

    audio_pred, audio_gt, audio_both = create_comparison_audio(
        audio,
        pred_beats,
        pred_downbeats,
        gt_beats,
        gt_downbeats,
        sr=sr,
        click_volume=click_volume,
    )

    save_audio(
        audio_pred, os.path.join(output_dir, f"track_{track_idx:03d}_pred.wav"), sr
    )
    save_audio(audio_gt, os.path.join(output_dir, f"track_{track_idx:03d}_gt.wav"), sr)
    save_audio(
        audio_both, os.path.join(output_dir, f"track_{track_idx:03d}_both.wav"), sr
    )


def main():
    parser = argparse.ArgumentParser(
        description="Evaluate beat tracking models with visualization and audio synthesis"
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default="outputs/baseline2",
        help="Base directory containing trained models (with 'beats' and 'downbeats' subdirs)",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=116,
        help="Number of samples to evaluate",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/eval_baseline2",
        help="Directory to save visualizations and audio",
    )
    parser.add_argument(
        "--visualize",
        action="store_true",
        help="Generate visualization plots for each track",
    )
    parser.add_argument(
        "--synthesize",
        action="store_true",
        help="Generate audio files with click tracks",
    )
    parser.add_argument(
        "--viz-tracks",
        type=int,
        default=5,
        help="Number of tracks to visualize/synthesize (default: 5)",
    )
    parser.add_argument(
        "--time-range",
        type=float,
        nargs=2,
        default=None,
        metavar=("START", "END"),
        help="Time range for visualization in seconds (default: full track)",
    )
    parser.add_argument(
        "--click-volume",
        type=float,
        default=0.5,
        help="Volume of click sounds relative to audio (0.0 to 1.0)",
    )
    parser.add_argument(
        "--summary-plot",
        action="store_true",
        help="Generate summary evaluation plot",
    )
    args = parser.parse_args()
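
    # Example invocation (sketch; "<module>" stands in for this script's
    # actual module path, which is not shown here):
    #     python -m <module> --visualize --synthesize --viz-tracks 3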

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Load whichever of the two models has a saved checkpoint; beats and
    # downbeats are evaluated independently.
    beat_model = None
    downbeat_model = None

    has_beats = False
    has_downbeats = False

    beats_dir = os.path.join(args.model_dir, "beats")
    downbeats_dir = os.path.join(args.model_dir, "downbeats")

    if os.path.exists(os.path.join(beats_dir, "model.safetensors")):
        beat_model = ResNet.from_pretrained(beats_dir).to(DEVICE)
        beat_model.eval()
        has_beats = True
        print(f"Loaded beat model from {beats_dir}")
    else:
        print(f"Warning: no beat model found in {beats_dir}")

    if os.path.exists(os.path.join(downbeats_dir, "model.safetensors")):
        downbeat_model = ResNet.from_pretrained(downbeats_dir).to(DEVICE)
        downbeat_model.eval()
        has_downbeats = True
        print(f"Loaded downbeat model from {downbeats_dir}")
    else:
        print(f"Warning: no downbeat model found in {downbeats_dir}")

    if not has_beats and not has_downbeats:
        print("No models found. Please run training first.")
        return

    predictions = []
    ground_truths = []
    audio_data = []

    # Evaluate on the first `num_samples` items of the training split.
    test_set = ds["train"].select(range(args.num_samples))

    print("Running evaluation...")
    for i, item in enumerate(tqdm(test_set)):
        waveform = torch.tensor(item["audio"]["array"], dtype=torch.float32)
        waveform_device = waveform.to(DEVICE)

        pred_entry = {"beats": [], "downbeats": []}

        if has_beats:
            act_b = get_activation_function(beat_model, waveform_device, DEVICE)
            pred_entry["beats"] = pick_peaks(act_b)

        if has_downbeats:
            act_d = get_activation_function(downbeat_model, waveform_device, DEVICE)
            pred_entry["downbeats"] = pick_peaks(act_d)

        predictions.append(pred_entry)
        ground_truths.append({"beats": item["beats"], "downbeats": item["downbeats"]})

        # Keep raw audio for the first few tracks in case we plot or
        # synthesize click tracks later.
        if (args.visualize or args.synthesize) and i < args.viz_tracks:
            audio_data.append(
                {
                    "audio": waveform.numpy(),
                    "sr": item["audio"]["sampling_rate"],
                    "pred": pred_entry,
                    "gt": ground_truths[-1],
                }
            )

    results = evaluate_all(predictions, ground_truths)
    print(format_results(results))

    if args.visualize or args.synthesize or args.summary_plot:
        os.makedirs(args.output_dir, exist_ok=True)

    if args.visualize:
        print(f"\nGenerating visualizations for {len(audio_data)} tracks...")
        viz_dir = os.path.join(args.output_dir, "plots")
        for i, data in enumerate(tqdm(audio_data, desc="Visualizing")):
            time_range = tuple(args.time_range) if args.time_range else None
            visualize_track(
                data["audio"],
                data["sr"],
                data["pred"]["beats"],
                data["pred"]["downbeats"],
                data["gt"]["beats"],
                data["gt"]["downbeats"],
                viz_dir,
                i,
                time_range=time_range,
            )
        print(f"Saved visualizations to {viz_dir}")

    if args.synthesize:
        print(f"\nSynthesizing audio for {len(audio_data)} tracks...")
        audio_dir = os.path.join(args.output_dir, "audio")
        for i, data in enumerate(tqdm(audio_data, desc="Synthesizing")):
            synthesize_audio(
                data["audio"],
                data["sr"],
                data["pred"]["beats"],
                data["pred"]["downbeats"],
                data["gt"]["beats"],
                data["gt"]["downbeats"],
                audio_dir,
                i,
                click_volume=args.click_volume,
            )
        print(f"Saved audio files to {audio_dir}")
        print("  *_pred.wav - Original audio with predicted beat clicks")
        print("  *_gt.wav   - Original audio with ground truth beat clicks")
        print("  *_both.wav - Original audio with both predicted and GT clicks")

    if args.summary_plot:
        from ..data.viz import plot_evaluation_summary, save_figure

        print("\nGenerating summary plot...")
        fig = plot_evaluation_summary(results, title="Beat Tracking Evaluation Summary")
        summary_path = os.path.join(args.output_dir, "evaluation_summary.png")
        save_figure(fig, summary_path)
        print(f"Saved summary plot to {summary_path}")


if __name__ == "__main__":
    main()