File size: 2,645 Bytes
31bf74c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
from datasets import load_dataset, Audio
# Value forwarded as ``num_proc`` to the ``filter`` and ``map`` calls further
# down. NOTE(review): ``None`` presumably means "no multiprocessing" - confirm
# against the datasets documentation before relying on it.
N_PROC = None

# Load the parsed chart dataset and immediately discard the columns the
# later steps never read (they only touch "audio", "metadata" and "oni").
ds = load_dataset("JacobLinCool/taiko-1000-parsed").remove_columns(
    ["tja", "hard", "normal", "easy", "ura"]
)
def filter_out_broken(example):
    """Return True when the example's audio samples can be accessed.

    Intended as a ``filter`` predicate: the function touches
    ``example["audio"]["array"]`` and reports whether that access succeeded.
    Presumably the access is what forces the audio to be decoded, so rows
    with unreadable audio fail here - confirm against the dataset's loader.

    Args:
        example: A single dataset row, expected to expose an "audio" entry
            that in turn exposes an "array" entry.

    Returns:
        True if the access succeeded, False if it raised any ordinary
        exception (missing key, decode failure, wrong type, ...).
    """
    try:
        example["audio"]["array"]
    except Exception:
        # Deliberately broad, because a broken row may fail in many ways, but
        # not a bare ``except``: KeyboardInterrupt and SystemExit must still
        # be able to stop a long-running filter pass.
        return False
    return True
# Keep only the rows accepted by filter_out_broken, then re-declare the
# "audio" column with a 16 kHz sampling rate.
# NOTE(review): the datasets docs describe ``batch_size`` as applying to
# batched calls only - confirm whether it has any effect here.
ds = ds.filter(
    filter_out_broken,
    num_proc=N_PROC,
    batch_size=32,
    writer_batch_size=32,
).cast_column("audio", Audio(sampling_rate=16000))
def build_beat_and_downbeat_labels(example):
    """Derive beat and downbeat times (in seconds) from the "oni" chart.

    Each chart segment is treated as one measure:

    - Downbeats: the segment's own timestamp (first beat of the measure).
    - Beats: ``measure_num`` evenly spaced positions starting at the segment
      timestamp, spaced by one beat of the measure's time signature.

    The tempo of a measure is resolved in this order:

    1. the BPM of the first note of the segment itself;
    2. if the segment has no notes, the BPM of the first note of the nearest
       later segment that has notes;
    3. if there is no such later segment, the tempo used for the previous
       measure (trailing empty measures keep the chart's last tempo);
    4. 120 BPM when nothing usable is found, or the value is not positive.

    A zero, negative or missing ``measure_den`` is treated as 4 (quarter-note
    beat) so that one malformed measure does not abort the whole run.

    Args:
        example: A dataset row providing ``example["metadata"]["TITLE"]`` and
            ``example["oni"]["segments"]``; every segment provides
            "timestamp", "measure_num", "measure_den" and "notes", and every
            note provides "bpm".

    Returns:
        A dict with "title", plus "beats" and "downbeats" as sorted lists of
        unique times.
    """
    default_bpm = 120.0  # last-resort tempo when the chart offers none
    default_den = 4  # quarter-note beat for unusable denominators

    title = example["metadata"]["TITLE"]
    segments = example["oni"]["segments"]

    # One reverse pass records, for every segment, the BPM of the first note
    # of the nearest segment at or after it that has notes. This replaces a
    # forward scan per empty segment, which was quadratic on long empty runs.
    upcoming_bpm = [None] * len(segments)
    nearest_bpm = None
    for idx in range(len(segments) - 1, -1, -1):
        seg_notes = segments[idx]["notes"]
        if seg_notes:
            nearest_bpm = seg_notes[0]["bpm"]
        upcoming_bpm[idx] = nearest_bpm

    beats = []
    downbeats = []
    previous_bpm = None  # tempo actually used for the preceding measure
    for segment, bpm in zip(segments, upcoming_bpm):
        seg_timestamp = segment["timestamp"]
        measure_num = segment["measure_num"]  # numerator (e.g., 4 in 4/4)
        measure_den = segment["measure_den"]  # denominator (e.g., 4 in 4/4)

        # Downbeat is the start of each measure.
        downbeats.append(seg_timestamp)

        if bpm is None:
            # Nothing at or after this measure carries a tempo: keep going at
            # the tempo of the measure before it instead of jumping to 120.
            bpm = previous_bpm
        if bpm is None or bpm <= 0:
            bpm = default_bpm
        previous_bpm = bpm

        if not measure_den or measure_den < 0:
            measure_den = default_den

        # 60/BPM is the length of a quarter note; scale it to the beat unit
        # named by the denominator (4 = quarter, 8 = eighth, ...).
        beat_duration = (60.0 / bpm) * (4.0 / measure_den)
        beats.extend(
            seg_timestamp + beat_idx * beat_duration
            for beat_idx in range(measure_num)
        )

    # Sort and deduplicate (in case of overlapping segments).
    return {
        "title": title,
        "beats": sorted(set(beats)),
        "downbeats": sorted(set(downbeats)),
    }
# Attach "title", "beats" and "downbeats" to every row, drop the raw chart
# and metadata columns they were computed from, and switch the output
# format to "torch".
# NOTE(review): the datasets docs describe ``batch_size`` as applying to
# batched calls only - confirm whether it has any effect here.
ds = ds.map(
    build_beat_and_downbeat_labels,
    remove_columns=["oni", "metadata"],
    num_proc=N_PROC,
    batch_size=32,
    writer_batch_size=32,
).with_format("torch")


if __name__ == "__main__":
    # Quick inspection when run as a script: the dataset summary followed by
    # the feature schema of the "train" split.
    print(ds)
    print(ds["train"].features)
|