File size: 2,645 Bytes
31bf74c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from datasets import load_dataset, Audio

# Worker-process count passed to the `datasets` filter/map calls below;
# None means no multiprocessing.
N_PROC = None

# Load the parsed Taiko chart dataset and drop the columns this script never reads.
ds = load_dataset("JacobLinCool/taiko-1000-parsed").remove_columns(
    ["tja", "hard", "normal", "easy", "ura"]
)


def filter_out_broken(example):
    """Return True if the example's audio array can be accessed, False otherwise.

    Used as a `datasets` filter predicate. Indexing ``example["audio"]["array"]``
    is expected to raise for rows whose audio cannot be decoded (presumably via
    lazy decoding in the `datasets` Audio feature -- TODO confirm), so any
    ordinary exception marks the row as broken.

    Only ``Exception`` subclasses are caught. ``KeyboardInterrupt`` and
    ``SystemExit`` propagate, so interrupting the filter pass stops the job
    instead of silently dropping the current row.
    """
    try:
        example["audio"]["array"]
    except Exception:
        return False
    return True


# Drop rows whose audio cannot be accessed, then cast the audio column so it
# decodes at a 16 kHz sampling rate.
ds = ds.filter(
    filter_out_broken,
    num_proc=N_PROC,
    batch_size=32,
    writer_batch_size=32,
).cast_column("audio", Audio(sampling_rate=16000))


def _next_note_bpms(segments):
    """Return, per segment index, the BPM of the first note at or after it.

    Entry ``i`` depends on which segments have notes:

    - If segment ``i`` has notes, it is ``segments[i]["notes"][0]["bpm"]``.
    - Otherwise it is the first-note BPM of the nearest later segment with notes.
    - If no later segment has notes, it is ``None``.

    Built in one backward pass, so the look-ahead costs O(len(segments)) in
    total rather than rescanning forward for every empty segment.
    """
    bpms = [None] * len(segments)
    upcoming = None
    for i in range(len(segments) - 1, -1, -1):
        notes = segments[i]["notes"]
        if notes:
            upcoming = notes[0]["bpm"]
        bpms[i] = upcoming
    return bpms


def build_beat_and_downbeat_labels(example):
    """
    Extract beat and downbeat times from the chart segments.

    - Downbeats: First beat of each measure (segment timestamp)
    - Beats: All beats within each measure based on time signature

    Tempo for a measure is taken, in order of preference, from:

    1. the first note of the measure;
    2. the first note of the next measure that has notes;
    3. the tempo used by the previous measure (this covers trailing empty measures);
    4. a default of 120 BPM.

    A resolved BPM that is not positive is also replaced by 120 BPM.

    Args:
        example: Dataset row with ``example["metadata"]["TITLE"]`` and
            ``example["oni"]["segments"]``. Each segment is a mapping with
            ``timestamp``, ``measure_num``, ``measure_den`` and ``notes``, and
            each note carries ``bpm``.

    Returns:
        Dict with ``title`` plus sorted, de-duplicated ``beats`` and ``downbeats``
        lists of times in seconds.
    """
    title = example["metadata"]["TITLE"]
    segments = example["oni"]["segments"]

    beats = []
    downbeats = []

    next_bpms = _next_note_bpms(segments)
    last_bpm = None  # BPM actually used for the previous measure

    for segment, bpm in zip(segments, next_bpms):
        seg_timestamp = segment["timestamp"]
        measure_num = segment["measure_num"]  # numerator (e.g., 4 in 4/4)
        measure_den = segment["measure_den"]  # denominator (e.g., 4 in 4/4)

        # Downbeat is the start of each measure
        downbeats.append(seg_timestamp)

        if bpm is None:
            # No note in this or any later measure: carry the previous
            # measure's tempo forward instead of jumping to the default.
            bpm = last_bpm
        if bpm is None or bpm <= 0:
            bpm = 120.0  # fallback default BPM
        last_bpm = bpm

        # One quarter note lasts 60/BPM seconds; scale by 4/denominator so the
        # beat unit matches the time signature (4 = quarter, 8 = eighth, ...).
        beat_duration = (60.0 / bpm) * (4.0 / measure_den)

        # Beat positions within this measure
        beats.extend(seg_timestamp + k * beat_duration for k in range(measure_num))

    # Sort and deduplicate (in case of overlapping segments)
    beats = sorted(set(beats))
    downbeats = sorted(set(downbeats))

    return {
        "title": title,
        "beats": beats,
        "downbeats": downbeats,
    }


# Replace the raw chart/metadata columns with beat/downbeat label columns, then
# expose the dataset in torch format.
ds = ds.map(
    build_beat_and_downbeat_labels,
    remove_columns=["oni", "metadata"],
    num_proc=N_PROC,
    batch_size=32,
    writer_batch_size=32,
).with_format("torch")

if __name__ == "__main__":
    # Quick sanity dump: split summary, then the train split's feature schema.
    print(ds)
    print(ds["train"].features)