JacobLinCool's picture
Upload folder using huggingface_hub
707cbac unverified
import torch
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm
from .utils import extract_context
class BeatTrackingDataset(Dataset):
    """Frame-level beat / downbeat classification dataset.

    Every sample is a fixed-length mono waveform window centred on one
    analysis frame, paired with a soft label: 1.0 for an annotated frame,
    0.25 for its direct neighbours ("fuzzy" positives) and 0.0 for randomly
    drawn negative frames. All audio is cached in memory at construction
    time so that ``__getitem__`` only slices and pads.
    """

    def __init__(
        self, hf_dataset, target_type="beats", sample_rate=16000, hop_length=160
    ):
        """
        Args:
            hf_dataset: HuggingFace dataset object. Each item is read as
                ``item["audio"]["array"]`` plus ``item["beats"]`` or
                ``item["downbeats"]`` (annotation times, converted to frames
                via ``t * sample_rate / hop_length``).
            target_type (str): "beats" or "downbeats". Determines which labels
                are treated as positive. Any other value falls back to "beats"
                (with a warning).
            sample_rate (int): Sample rate used to convert annotation times
                to frame indices.
            hop_length (int): Number of samples between consecutive frames.
        """
        super().__init__()

        if target_type not in ("beats", "downbeats"):
            # Keep the historical fallback, but make typos visible instead of
            # silently training on the wrong annotation set.
            import warnings

            warnings.warn(
                f"Unknown target_type {target_type!r}; "
                "expected 'beats' or 'downbeats'. Falling back to 'beats'.",
                stacklevel=2,
            )

        self.sr = sample_rate
        self.hop_length = hop_length
        self.target_type = target_type

        # Context window size in samples (7 frames = 70ms at 100fps),
        # taken on each side of the centre frame.
        self.context_frames = 7
        self.context_samples = (self.context_frames * 2 + 1) * hop_length + max(
            368, 736, 1488
        )  # extra for FFT window

        # Cache audio arrays in memory for fast access
        self.audio_cache = []
        # One entry per training sample: (track index, frame index, label).
        self.indices = []
        self._prepare_indices(hf_dataset)

    def _positive_frame_weights(self, gt_times, n_frames):
        """Map every positive frame of one track to its label weight.

        Annotated frames get 1.0 and their direct neighbours 0.25. Each frame
        appears at most once: when a neighbour of one annotation is itself an
        annotated frame it keeps the 1.0 label. Annotations that fall outside
        ``[0, n_frames)`` are ignored together with their neighbours.

        Args:
            gt_times: Iterable of annotation times.
            n_frames (int): Number of frames available in the track.

        Returns:
            dict[int, float]: frame index -> label weight.
        """
        centers = sorted(
            {int(t * self.sr / self.hop_length) for t in gt_times}
        )
        centers = [frame for frame in centers if 0 <= frame < n_frames]

        # Centre frames first so that a neighbour can never downgrade them.
        weights = {frame: 1.0 for frame in centers}
        for frame in centers:
            for neighbor in (frame - 1, frame + 1):
                if 0 <= neighbor < n_frames:
                    weights.setdefault(neighbor, 0.25)
        return weights

    def _sample_negative_frames(self, n_frames, excluded, num_neg):
        """Draw up to ``num_neg`` distinct negative frames for one track.

        Frames are drawn uniformly without replacement (via the global
        ``np.random`` state) from ``[0, n_frames)`` minus ``excluded``. Fewer
        than ``num_neg`` frames are returned when not enough candidates exist.

        Args:
            n_frames (int): Number of frames available in the track.
            excluded: Collection of frame indices that must not be drawn.
            num_neg (int): Desired number of negatives.

        Returns:
            list[int]: The selected frame indices.
        """
        if num_neg <= 0 or n_frames <= 0:
            return []

        all_frames = np.arange(n_frames)
        candidates = all_frames[~np.isin(all_frames, list(excluded))]
        if candidates.size == 0:
            return []

        count = min(num_neg, candidates.size)
        return np.random.choice(candidates, size=count, replace=False).tolist()

    def _prepare_indices(self, hf_dataset):
        """
        Prepares balanced indices and caches audio.
        Paper Section 4.5: Uses "Fuzzier" training examples (neighbors weighted less).
        """
        print(f"Preparing dataset indices for target: {self.target_type}...")

        for i, item in tqdm(
            enumerate(hf_dataset), total=len(hf_dataset), desc="Building indices"
        ):
            # Cache audio array (convert to numpy if tensor). __getitem__
            # always returns float32, so caching as float32 yields identical
            # samples while halving memory for float64 input.
            audio = item["audio"]["array"]
            if hasattr(audio, "numpy"):
                audio = audio.numpy()
            audio = np.asarray(audio, dtype=np.float32)
            self.audio_cache.append(audio)

            # Calculate total frames available in audio
            n_frames = int(len(audio) / self.hop_length)

            # Select ground truth based on target_type: only downbeats, or
            # all beats (downbeats are also beats).
            if self.target_type == "downbeats":
                gt_times = item["downbeats"]
            else:
                gt_times = item["beats"]

            # Convert to list if tensor
            if hasattr(gt_times, "tolist"):
                gt_times = gt_times.tolist()

            # --- Positive Examples (with Fuzziness) ---
            # "define a single frame before and after each annotated onset to
            # be additional positive examples"
            weights = self._positive_frame_weights(gt_times, n_frames)
            self.indices.extend(
                (i, frame, weight) for frame, weight in weights.items()
            )

            # --- Negative Examples ---
            # Paper uses "all others as negative", but we balance 2:1 for stable SGD.
            negatives = self._sample_negative_frames(
                n_frames, weights, 2 * len(weights)
            )
            self.indices.extend((i, frame, 0.0) for frame in negatives)

        print(
            f"Dataset ready. {len(self.indices)} samples, {len(self.audio_cache)} tracks cached."
        )

    def __len__(self):
        """Return the number of (track, frame, label) samples."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the waveform window and label for sample ``idx``.

        Returns:
            tuple: ``(waveform, label)`` where ``waveform`` is a float32 tensor
            of ``2 * (context_samples // 2)`` samples centred on the frame
            (zero-padded where the window leaves the track) and ``label`` is a
            float32 tensor of shape ``(1,)``.
        """
        track_idx, frame_idx, label = self.indices[idx]

        # Fast lookup from cache
        audio = self.audio_cache[track_idx]
        audio_len = len(audio)

        # Calculate sample range for context window
        center_sample = frame_idx * self.hop_length
        half_context = self.context_samples // 2
        start = center_sample - half_context
        end = center_sample + half_context

        # Amount of the window that lies before / after the track.
        pad_left = max(0, -start)
        pad_right = max(0, end - audio_len)

        # Extract the in-range part and zero-pad the rest.
        chunk = audio[max(0, start) : min(audio_len, end)]
        if pad_left > 0 or pad_right > 0:
            chunk = np.pad(chunk, (pad_left, pad_right), mode="constant")

        # torch.tensor copies, so the returned tensor never aliases the cache.
        waveform = torch.tensor(chunk, dtype=torch.float32)
        return waveform, torch.tensor([label], dtype=torch.float32)