import torch
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm


class BeatTrackingDataset(Dataset):
    def __init__(
        self,
        hf_dataset,
        target_type="beats",
        sample_rate=16000,
        hop_length=160,
        context_frames=50,
    ):
        """
        Args:
            hf_dataset: HuggingFace dataset object.
            target_type (str): "beats" or "downbeats". Determines which labels
                are treated as positive.
            sample_rate (int): Audio sample rate in Hz.
            hop_length (int): Hop size in samples between analysis frames.
            context_frames (int): Number of frames before and after the center
                frame. Total frames = 2 * context_frames + 1, so the default
                of 50 yields 101 frames (~1 s at the default settings).
        """
        self.sr = sample_rate
        self.hop_length = hop_length
        self.target_type = target_type
        self.context_frames = context_frames

        # Context window size in samples: we need enough samples for the
        # center frame +/- context_frames, PLUS the window size of the
        # largest FFT to compute the edge frames correctly. The largest
        # window in MultiViewSpectrogram is 1488 samples.
        self.context_samples = (self.context_frames * 2 + 1) * hop_length + 1488

        # Cache audio arrays in memory for fast access.
        self.audio_cache = []
        self.indices = []
        self._prepare_indices(hf_dataset)

    def _prepare_indices(self, hf_dataset):
        """
        Prepares balanced (track, frame, label) indices and caches audio.
        Uses the same "fuzzy" positive-labeling strategy as the baseline:
        the frames adjacent to each annotated frame receive a soft target.
        """
        print(f"Preparing dataset indices for target: {self.target_type}...")
        for i, item in tqdm(
            enumerate(hf_dataset), total=len(hf_dataset), desc="Building indices"
        ):
            # Cache the audio array (convert to numpy if it is a tensor).
            audio = item["audio"]["array"]
            if hasattr(audio, "numpy"):
                audio = audio.numpy()
            self.audio_cache.append(audio)

            # Total number of analysis frames available in this track.
            audio_len = len(audio)
            n_frames = int(audio_len / self.hop_length)

            # Select ground truth based on target_type.
            if self.target_type == "downbeats":
                gt_times = item["downbeats"]
            else:
                gt_times = item["beats"]

            # Convert to a plain list if it is a tensor.
            if hasattr(gt_times, "tolist"):
                gt_times = gt_times.tolist()

            gt_frames = set(int(t * self.sr / self.hop_length) for t in gt_times)

            # --- Positive examples (with fuzziness) ---
            pos_frames = set()
            for bf in gt_frames:
                if 0 <= bf < n_frames:
                    self.indices.append((i, bf, 1.0))  # Center frame
                    pos_frames.add(bf)
                    # Immediate neighbors are weighted at 0.25.
                    if 0 <= bf - 1 < n_frames:
                        self.indices.append((i, bf - 1, 0.25))
                        pos_frames.add(bf - 1)
                    if 0 <= bf + 1 < n_frames:
                        self.indices.append((i, bf + 1, 0.25))
                        pos_frames.add(bf + 1)

            # --- Negative examples ---
            # Sample negatives at a 2:1 negative-to-positive ratio. The
            # attempts cap guards against an infinite loop on tracks where
            # nearly every frame is (or neighbors) a beat.
            num_pos = len(pos_frames)
            num_neg = num_pos * 2
            count = 0
            attempts = 0
            while count < num_neg and attempts < num_neg * 5:
                f = np.random.randint(0, n_frames)
                if f not in pos_frames:
                    self.indices.append((i, f, 0.0))
                    count += 1
                attempts += 1

        print(
            f"Dataset ready. {len(self.indices)} samples, "
            f"{len(self.audio_cache)} tracks cached."
        )
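
    # Worked example of the labeling scheme above, using the default
    # sr=16000 and hop_length=160: a beat annotated at t = 1.00 s maps to
    # frame int(1.00 * 16000 / 160) = 100, so _prepare_indices emits
    # (track, 100, 1.0) plus the fuzzy neighbors (track, 99, 0.25) and
    # (track, 101, 0.25), along with roughly twice as many randomly
    # sampled non-beat frames labeled 0.0.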

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        track_idx, frame_idx, label = self.indices[idx]

        # Fast lookup from cache
        audio = self.audio_cache[track_idx]
        audio_len = len(audio)

        # Calculate sample range for the context window,
        # centered around center_sample.
        center_sample = frame_idx * self.hop_length
        half_context = self.context_samples // 2
        start = center_sample - half_context
        end = center_sample + half_context

        # Handle padding if the window runs past either end of the track.
        pad_left = max(0, -start)
        pad_right = max(0, end - audio_len)
        valid_start = max(0, start)
        valid_end = min(audio_len, end)

        # Extract the audio chunk and zero-pad it to the full window length.
        chunk = audio[valid_start:valid_end]
        if pad_left > 0 or pad_right > 0:
            chunk = np.pad(chunk, (pad_left, pad_right), mode="constant")

        waveform = torch.tensor(chunk, dtype=torch.float32)
        return waveform, torch.tensor([label], dtype=torch.float32)
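

if __name__ == "__main__":
    # Minimal usage sketch. The dataset id below is a hypothetical
    # placeholder, not part of this module: any HuggingFace dataset with
    # "audio", "beats", and "downbeats" columns in the expected format
    # should work.
    from datasets import load_dataset
    from torch.utils.data import DataLoader

    hf_ds = load_dataset("user/beat-annotations", split="train")  # hypothetical id
    ds = BeatTrackingDataset(hf_ds, target_type="beats")
    loader = DataLoader(ds, batch_size=64, shuffle=True)

    # With the defaults, each waveform is (2*50 + 1)*160 + 1488 = 17648
    # samples long, so a batch is (64, 17648) and the labels are (64, 1).
    waveform, label = next(iter(loader))
    print(waveform.shape, label.shape)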