import torch import numpy as np from tqdm import tqdm from scipy.signal import find_peaks import argparse import os from .model import SEResNeXt from ..baseline1.utils import MultiViewSpectrogram from ..data.load import ds from ..data.eval import evaluate_all, format_results def get_activation_function(model, waveform, device): """ Computes probability curve over time. """ processor = MultiViewSpectrogram().to(device) waveform = waveform.unsqueeze(0).to(device) with torch.no_grad(): spec = processor(waveform) # Normalize mean = spec.mean(dim=(2, 3), keepdim=True) std = spec.std(dim=(2, 3), keepdim=True) + 1e-6 spec = (spec - mean) / std # Batchify with sliding window # Context frames = 50, so total window = 101. # Pad time by 50 on each side. spec = torch.nn.functional.pad(spec, (50, 50)) # Pad time windows = spec.unfold(3, 101, 1) # (1, 3, 80, Time, 101) windows = windows.permute(3, 0, 1, 2, 4).squeeze(1) # (Time, 3, 80, 101) # Inference activations = [] batch_size = 128 # Reduced batch size for i in range(0, len(windows), batch_size): batch = windows[i : i + batch_size] out = model(batch) activations.append(out.cpu().numpy()) return np.concatenate(activations).flatten() def pick_peaks(activations, hop_length=160, sr=16000): """ Smooth with Hamming window and report local maxima. """ # Smoothing window = np.hamming(5) window /= window.sum() smoothed = np.convolve(activations, window, mode="same") # Peak Picking peaks, _ = find_peaks(smoothed, height=0.5, distance=5) timestamps = peaks * hop_length / sr return timestamps.tolist() def visualize_track( audio: np.ndarray, sr: int, pred_beats: list[float], pred_downbeats: list[float], gt_beats: list[float], gt_downbeats: list[float], output_dir: str, track_idx: int, time_range: tuple[float, float] | None = None, ): """ Create and save visualizations for a single track. """ from ..data.viz import plot_waveform_with_beats, save_figure os.makedirs(output_dir, exist_ok=True) # Full waveform plot fig = plot_waveform_with_beats( audio, sr, pred_beats, gt_beats, pred_downbeats, gt_downbeats, title=f"Track {track_idx}: Beat Comparison", time_range=time_range, ) save_figure(fig, os.path.join(output_dir, f"track_{track_idx:03d}.png")) def synthesize_audio( audio: np.ndarray, sr: int, pred_beats: list[float], pred_downbeats: list[float], gt_beats: list[float], gt_downbeats: list[float], output_dir: str, track_idx: int, click_volume: float = 0.5, ): """ Create and save audio files with click tracks for a single track. """ from ..data.audio import create_comparison_audio, save_audio os.makedirs(output_dir, exist_ok=True) # Create comparison audio audio_pred, audio_gt, audio_both = create_comparison_audio( audio, pred_beats, pred_downbeats, gt_beats, gt_downbeats, sr=sr, click_volume=click_volume, ) # Save audio files save_audio( audio_pred, os.path.join(output_dir, f"track_{track_idx:03d}_pred.wav"), sr ) save_audio(audio_gt, os.path.join(output_dir, f"track_{track_idx:03d}_gt.wav"), sr) save_audio( audio_both, os.path.join(output_dir, f"track_{track_idx:03d}_both.wav"), sr ) def main(): parser = argparse.ArgumentParser( description="Evaluate beat tracking models with visualization and audio synthesis" ) parser.add_argument( "--model-dir", type=str, default="outputs/baseline3", help="Base directory containing trained models (with 'beats' and 'downbeats' subdirs)", ) parser.add_argument( "--num-samples", type=int, default=116, help="Number of samples to evaluate", ) parser.add_argument( "--output-dir", type=str, default="outputs/eval_baseline3", help="Directory to save visualizations and audio", ) parser.add_argument( "--visualize", action="store_true", help="Generate visualization plots for each track", ) parser.add_argument( "--synthesize", action="store_true", help="Generate audio files with click tracks", ) parser.add_argument( "--viz-tracks", type=int, default=5, help="Number of tracks to visualize/synthesize (default: 5)", ) parser.add_argument( "--time-range", type=float, nargs=2, default=None, metavar=("START", "END"), help="Time range for visualization in seconds (default: full track)", ) parser.add_argument( "--click-volume", type=float, default=0.5, help="Volume of click sounds relative to audio (0.0 to 1.0)", ) parser.add_argument( "--summary-plot", action="store_true", help="Generate summary evaluation plot", ) args = parser.parse_args() DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Load BOTH models using from_pretrained beat_model = None downbeat_model = None has_beats = False has_downbeats = False beats_dir = os.path.join(args.model_dir, "beats") downbeats_dir = os.path.join(args.model_dir, "downbeats") if os.path.exists(os.path.join(beats_dir, "model.safetensors")): beat_model = SEResNeXt.from_pretrained(beats_dir).to(DEVICE) beat_model.eval() has_beats = True print(f"Loaded Beat Model from {beats_dir}") else: print(f"Warning: No beat model found in {beats_dir}") if os.path.exists(os.path.join(downbeats_dir, "model.safetensors")): downbeat_model = SEResNeXt.from_pretrained(downbeats_dir).to(DEVICE) downbeat_model.eval() has_downbeats = True print(f"Loaded Downbeat Model from {downbeats_dir}") else: print(f"Warning: No downbeat model found in {downbeats_dir}") if not has_beats and not has_downbeats: print("No models found. Please run training first.") return predictions = [] ground_truths = [] audio_data = [] # Store audio for visualization/synthesis # Eval on specified number of tracks test_set = ds["train"].select(range(args.num_samples)) print("Running evaluation...") for i, item in enumerate(tqdm(test_set)): waveform = torch.tensor(item["audio"]["array"], dtype=torch.float32) waveform_device = waveform.to(DEVICE) pred_entry = {"beats": [], "downbeats": []} # 1. Predict Beats if has_beats: act_b = get_activation_function(beat_model, waveform_device, DEVICE) pred_entry["beats"] = pick_peaks(act_b) # 2. Predict Downbeats if has_downbeats: act_d = get_activation_function(downbeat_model, waveform_device, DEVICE) pred_entry["downbeats"] = pick_peaks(act_d) predictions.append(pred_entry) ground_truths.append({"beats": item["beats"], "downbeats": item["downbeats"]}) # Store audio for later visualization/synthesis if args.visualize or args.synthesize: if i < args.viz_tracks: audio_data.append( { "audio": waveform.numpy(), "sr": item["audio"]["sampling_rate"], "pred": pred_entry, "gt": ground_truths[-1], } ) # Run evaluation results = evaluate_all(predictions, ground_truths) print(format_results(results)) # Create output directory if args.visualize or args.synthesize or args.summary_plot: os.makedirs(args.output_dir, exist_ok=True) # Generate visualizations if args.visualize: print(f"\nGenerating visualizations for {len(audio_data)} tracks...") viz_dir = os.path.join(args.output_dir, "plots") for i, data in enumerate(tqdm(audio_data, desc="Visualizing")): time_range = tuple(args.time_range) if args.time_range else None visualize_track( data["audio"], data["sr"], data["pred"]["beats"], data["pred"]["downbeats"], data["gt"]["beats"], data["gt"]["downbeats"], viz_dir, i, time_range=time_range, ) print(f"Saved visualizations to {viz_dir}") # Generate audio with clicks if args.synthesize: print(f"\nSynthesizing audio for {len(audio_data)} tracks...") audio_dir = os.path.join(args.output_dir, "audio") for i, data in enumerate(tqdm(audio_data, desc="Synthesizing")): synthesize_audio( data["audio"], data["sr"], data["pred"]["beats"], data["pred"]["downbeats"], data["gt"]["beats"], data["gt"]["downbeats"], audio_dir, i, click_volume=args.click_volume, ) print(f"Saved audio files to {audio_dir}") print(" *_pred.wav - Original audio with predicted beat clicks") print(" *_gt.wav - Original audio with ground truth beat clicks") print(" *_both.wav - Original audio with both predicted and GT clicks") # Generate summary plot if args.summary_plot: from ..data.viz import plot_evaluation_summary, save_figure print("\nGenerating summary plot...") fig = plot_evaluation_summary(results, title="Beat Tracking Evaluation Summary") summary_path = os.path.join(args.output_dir, "evaluation_summary.png") save_figure(fig, summary_path) print(f"Saved summary plot to {summary_path}") if __name__ == "__main__": main()