import torch
import numpy as np
from tqdm import tqdm
from scipy.signal import find_peaks
import argparse
import os
from .model import ODCNN
from .utils import MultiViewSpectrogram
from ..data.load import ds
from ..data.eval import evaluate_all, format_results
def get_activation_function(model, waveform, device):
"""
Compute the model's frame-wise activation (probability) curve over time by
running the CNN on 15-frame sliding windows of the multi-view spectrogram.
"""
processor = MultiViewSpectrogram().to(device)
waveform = waveform.unsqueeze(0).to(device)
with torch.no_grad():
spec = processor(waveform)
# Normalize each spectrogram channel to zero mean and unit variance
mean = spec.mean(dim=(2, 3), keepdim=True)
std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
spec = (spec - mean) / std
# Batchify with sliding window
spec = torch.nn.functional.pad(spec, (7, 7)) # Pad time
windows = spec.unfold(3, 15, 1) # (1, 3, 80, Time, 15)
windows = windows.permute(3, 0, 1, 2, 4).squeeze(1) # (Time, 3, 80, 15)
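# Note: the 7-frame padding on each side plus 15-frame windows (stride 1) yields
# exactly one window centered on every spectrogram frame, so the activation curve
# below has one value per frame.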
# Inference
activations = []
batch_size = 512
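# Run the CNN over the windows in fixed-size chunks to keep memory use bounded.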
for i in range(0, len(windows), batch_size):
batch = windows[i : i + batch_size]
out = model(batch)
activations.append(out.cpu().numpy())
return np.concatenate(activations).flatten()
def pick_peaks(activations, hop_length=160, sr=16000):
"""
Smooth the activation curve with a Hamming window and return the times
(in seconds) of local maxima above the detection threshold.
"""
# Smoothing
window = np.hamming(5)
window /= window.sum()
smoothed = np.convolve(activations, window, mode="same")
# Peak Picking
peaks, _ = find_peaks(smoothed, height=0.5, distance=5)
timestamps = peaks * hop_length / sr
return timestamps.tolist()
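# Usage sketch (mirroring main() below): turn raw audio into beat timestamps.
#   act = get_activation_function(beat_model, waveform, device)
#   beats = pick_peaks(act)  # beat times in seconds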
def visualize_track(
audio: np.ndarray,
sr: int,
pred_beats: list[float],
pred_downbeats: list[float],
gt_beats: list[float],
gt_downbeats: list[float],
output_dir: str,
track_idx: int,
time_range: tuple[float, float] | None = None,
):
"""
Create and save visualizations for a single track.
"""
from ..data.viz import plot_waveform_with_beats, save_figure
os.makedirs(output_dir, exist_ok=True)
# Full waveform plot
fig = plot_waveform_with_beats(
audio,
sr,
pred_beats,
gt_beats,
pred_downbeats,
gt_downbeats,
title=f"Track {track_idx}: Beat Comparison",
time_range=time_range,
)
save_figure(fig, os.path.join(output_dir, f"track_{track_idx:03d}.png"))
def synthesize_audio(
audio: np.ndarray,
sr: int,
pred_beats: list[float],
pred_downbeats: list[float],
gt_beats: list[float],
gt_downbeats: list[float],
output_dir: str,
track_idx: int,
click_volume: float = 0.5,
):
"""
Create and save audio files with click tracks for a single track.
"""
from ..data.audio import create_comparison_audio, save_audio
os.makedirs(output_dir, exist_ok=True)
# Create comparison audio
audio_pred, audio_gt, audio_both = create_comparison_audio(
audio,
pred_beats,
pred_downbeats,
gt_beats,
gt_downbeats,
sr=sr,
click_volume=click_volume,
)
# Save audio files
save_audio(
audio_pred, os.path.join(output_dir, f"track_{track_idx:03d}_pred.wav"), sr
)
save_audio(audio_gt, os.path.join(output_dir, f"track_{track_idx:03d}_gt.wav"), sr)
save_audio(
audio_both, os.path.join(output_dir, f"track_{track_idx:03d}_both.wav"), sr
)
def main():
parser = argparse.ArgumentParser(
description="Evaluate beat tracking models with visualization and audio synthesis"
)
parser.add_argument(
"--model-dir",
type=str,
default="outputs/baseline1",
help="Base directory containing trained models (with 'beats' and 'downbeats' subdirs)",
)
parser.add_argument(
"--num-samples",
type=int,
default=116,
help="Number of samples to evaluate",
)
parser.add_argument(
"--output-dir",
type=str,
default="outputs/eval_baseline1",
help="Directory to save visualizations and audio",
)
parser.add_argument(
"--visualize",
action="store_true",
help="Generate visualization plots for each track",
)
parser.add_argument(
"--synthesize",
action="store_true",
help="Generate audio files with click tracks",
)
parser.add_argument(
"--viz-tracks",
type=int,
default=5,
help="Number of tracks to visualize/synthesize (default: 5)",
)
parser.add_argument(
"--time-range",
type=float,
nargs=2,
default=None,
metavar=("START", "END"),
help="Time range for visualization in seconds (default: full track)",
)
parser.add_argument(
"--click-volume",
type=float,
default=0.5,
help="Volume of click sounds relative to audio (0.0 to 1.0)",
)
parser.add_argument(
"--summary-plot",
action="store_true",
help="Generate summary evaluation plot",
)
args = parser.parse_args()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Load both beat and downbeat models (if present) using from_pretrained
beat_model = None
downbeat_model = None
has_beats = False
has_downbeats = False
beats_dir = os.path.join(args.model_dir, "beats")
downbeats_dir = os.path.join(args.model_dir, "downbeats")
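# Expected layout (as produced by training): <model-dir>/beats/model.safetensors and
# <model-dir>/downbeats/model.safetensors; either model may be missing, in which
# case that head is simply skipped.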
if os.path.exists(os.path.join(beats_dir, "model.safetensors")):
beat_model = ODCNN.from_pretrained(beats_dir).to(DEVICE)
beat_model.eval()
has_beats = True
print(f"Loaded Beat Model from {beats_dir}")
else:
print(f"Warning: No beat model found in {beats_dir}")
if os.path.exists(os.path.join(downbeats_dir, "model.safetensors")):
downbeat_model = ODCNN.from_pretrained(downbeats_dir).to(DEVICE)
downbeat_model.eval()
has_downbeats = True
print(f"Loaded Downbeat Model from {downbeats_dir}")
else:
print(f"Warning: No downbeat model found in {downbeats_dir}")
if not has_beats and not has_downbeats:
print("No models found. Please run training first.")
return
predictions = []
ground_truths = []
audio_data = [] # Store audio for visualization/synthesis
# Evaluate on the first --num-samples tracks
test_set = ds["train"].select(range(args.num_samples))
print("Running evaluation...")
for i, item in enumerate(tqdm(test_set)):
waveform = torch.tensor(item["audio"]["array"], dtype=torch.float32)
waveform_device = waveform.to(DEVICE)
pred_entry = {"beats": [], "downbeats": []}
# 1. Predict Beats
if has_beats:
act_b = get_activation_function(beat_model, waveform_device, DEVICE)
pred_entry["beats"] = pick_peaks(act_b)
# 2. Predict Downbeats
if has_downbeats:
act_d = get_activation_function(downbeat_model, waveform_device, DEVICE)
pred_entry["downbeats"] = pick_peaks(act_d)
predictions.append(pred_entry)
ground_truths.append({"beats": item["beats"], "downbeats": item["downbeats"]})
# Store audio for later visualization/synthesis
if args.visualize or args.synthesize:
if i < args.viz_tracks:
audio_data.append(
{
"audio": waveform.numpy(),
"sr": item["audio"]["sampling_rate"],
"pred": pred_entry,
"gt": ground_truths[-1],
}
)
# Run evaluation
results = evaluate_all(predictions, ground_truths)
print(format_results(results))
# Create output directory
if args.visualize or args.synthesize or args.summary_plot:
os.makedirs(args.output_dir, exist_ok=True)
# Generate visualizations
if args.visualize:
print(f"\nGenerating visualizations for {len(audio_data)} tracks...")
viz_dir = os.path.join(args.output_dir, "plots")
for i, data in enumerate(tqdm(audio_data, desc="Visualizing")):
time_range = tuple(args.time_range) if args.time_range else None
visualize_track(
data["audio"],
data["sr"],
data["pred"]["beats"],
data["pred"]["downbeats"],
data["gt"]["beats"],
data["gt"]["downbeats"],
viz_dir,
i,
time_range=time_range,
)
print(f"Saved visualizations to {viz_dir}")
# Generate audio with clicks
if args.synthesize:
print(f"\nSynthesizing audio for {len(audio_data)} tracks...")
audio_dir = os.path.join(args.output_dir, "audio")
for i, data in enumerate(tqdm(audio_data, desc="Synthesizing")):
synthesize_audio(
data["audio"],
data["sr"],
data["pred"]["beats"],
data["pred"]["downbeats"],
data["gt"]["beats"],
data["gt"]["downbeats"],
audio_dir,
i,
click_volume=args.click_volume,
)
print(f"Saved audio files to {audio_dir}")
print(" *_pred.wav - Original audio with predicted beat clicks")
print(" *_gt.wav - Original audio with ground truth beat clicks")
print(" *_both.wav - Original audio with both predicted and GT clicks")
# Generate summary plot
if args.summary_plot:
from ..data.viz import plot_evaluation_summary, save_figure
print("\nGenerating summary plot...")
fig = plot_evaluation_summary(results, title="Beat Tracking Evaluation Summary")
summary_path = os.path.join(args.output_dir, "evaluation_summary.png")
save_figure(fig, summary_path)
print(f"Saved summary plot to {summary_path}")
if __name__ == "__main__":
main()