OliBomby committed on
Commit
7b967f9
·
verified ·
1 Parent(s): 9bbf03d

Add CM3P model

Browse files
audio_feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_cm3p.CM3PProcessor"
4
+ },
5
+ "chunk_length": 30,
6
+ "dither": 0.0,
7
+ "feature_extractor_type": "WhisperFeatureExtractor",
8
+ "feature_size": 80,
9
+ "hop_length": 160,
10
+ "n_fft": 400,
11
+ "n_samples": 480000,
12
+ "nb_max_frames": 3000,
13
+ "padding_side": "right",
14
+ "padding_value": 0.0,
15
+ "processor_class": "CM3PProcessor",
16
+ "return_attention_mask": false,
17
+ "sampling_rate": 16000
18
+ }
beatmap_parser/preprocessor_config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_distances": false,
3
+ "add_hitsounds": true,
4
+ "add_kiai": true,
5
+ "add_mania_sv": true,
6
+ "add_positions": true,
7
+ "add_snapping": false,
8
+ "add_sv": true,
9
+ "add_timing": true,
10
+ "add_timing_points": true,
11
+ "auto_map": {
12
+ "AutoProcessor": "processing_cm3p.CM3PProcessor"
13
+ },
14
+ "feature_extractor_type": "CM3PBeatmapParser",
15
+ "mania_bpm_normalized_scroll_speed": true,
16
+ "processor_class": "CM3PProcessor",
17
+ "slider_version": 2
18
+ }
beatmap_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "[AUDIO_BOS]",
4
+ "[AUDIO_EOS]",
5
+ "[AUDIO]"
6
+ ],
7
+ "bos_token": {
8
+ "content": "[BOS]",
9
+ "lstrip": false,
10
+ "normalized": false,
11
+ "rstrip": false,
12
+ "single_word": false
13
+ },
14
+ "cls_token": {
15
+ "content": "[CLS]",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "eos_token": {
22
+ "content": "[EOS]",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false
27
+ },
28
+ "mask_token": {
29
+ "content": "[MASK]",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false
34
+ },
35
+ "pad_token": {
36
+ "content": "[PAD]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false
41
+ },
42
+ "sep_token": {
43
+ "content": "[SEP]",
44
+ "lstrip": false,
45
+ "normalized": false,
46
+ "rstrip": false,
47
+ "single_word": false
48
+ },
49
+ "unk_token": {
50
+ "content": "[UNK]",
51
+ "lstrip": false,
52
+ "normalized": false,
53
+ "rstrip": false,
54
+ "single_word": false
55
+ }
56
+ }
beatmap_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_cls_token": true,
3
+ "added_tokens_decoder": {
4
+ "3958": {
5
+ "content": "[BOS]",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "3959": {
13
+ "content": "[EOS]",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "3960": {
21
+ "content": "[UNK]",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "3961": {
29
+ "content": "[SEP]",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "3962": {
37
+ "content": "[PAD]",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "3963": {
45
+ "content": "[CLS]",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "3964": {
53
+ "content": "[MASK]",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "3965": {
61
+ "content": "[AUDIO_BOS]",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "3966": {
69
+ "content": "[AUDIO_EOS]",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "3967": {
77
+ "content": "[AUDIO]",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ }
84
+ },
85
+ "additional_special_tokens": [
86
+ "[AUDIO_BOS]",
87
+ "[AUDIO_EOS]",
88
+ "[AUDIO]"
89
+ ],
90
+ "auto_map": {
91
+ "AutoProcessor": "processing_cm3p.CM3PProcessor"
92
+ },
93
+ "bos_token": "[BOS]",
94
+ "clean_up_tokenization_spaces": false,
95
+ "cls_token": "[CLS]",
96
+ "distance_step": 4,
97
+ "eos_token": "[EOS]",
98
+ "extra_special_tokens": {},
99
+ "mask_token": "[MASK]",
100
+ "max_distance": 640,
101
+ "max_time": 16000,
102
+ "min_time": 0,
103
+ "model_max_length": 1000000000000000019884624838656,
104
+ "pad_token": "[PAD]",
105
+ "position_range": [
106
+ -256,
107
+ 768,
108
+ -256,
109
+ 640
110
+ ],
111
+ "position_split_axes": true,
112
+ "position_step": 4,
113
+ "processor_class": "CM3PProcessor",
114
+ "sep_token": "[SEP]",
115
+ "separate_new_combo_token": false,
116
+ "time_step": 10,
117
+ "tokenizer_class": "CM3PBeatmapTokenizer",
118
+ "unk_token": "[UNK]"
119
+ }
beatmap_tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
metadata_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "[DIFFICULTY_UNK]",
4
+ "[YEAR_UNK]",
5
+ "[MODE_UNK]",
6
+ "[STATUS_UNK]",
7
+ "[MAPPER_UNK]",
8
+ "[CS_UNK]",
9
+ "[HITSOUNDED_UNK]",
10
+ "[SONG_LENGTH_UNK]",
11
+ "[SONG_POSITION_UNK]",
12
+ "[GLOBAL_SV_UNK]",
13
+ "[MANIA_KEYCOUNT_UNK]",
14
+ "[HOLD_NOTE_RATIO_UNK]",
15
+ "[SCROLL_SPEED_RATIO_UNK]",
16
+ "[TAG_UNK]"
17
+ ],
18
+ "bos_token": {
19
+ "content": "[BOS]",
20
+ "lstrip": false,
21
+ "normalized": false,
22
+ "rstrip": false,
23
+ "single_word": false
24
+ },
25
+ "cls_token": {
26
+ "content": "[CLS]",
27
+ "lstrip": false,
28
+ "normalized": false,
29
+ "rstrip": false,
30
+ "single_word": false
31
+ },
32
+ "eos_token": {
33
+ "content": "[EOS]",
34
+ "lstrip": false,
35
+ "normalized": false,
36
+ "rstrip": false,
37
+ "single_word": false
38
+ },
39
+ "pad_token": {
40
+ "content": "[PAD]",
41
+ "lstrip": false,
42
+ "normalized": false,
43
+ "rstrip": false,
44
+ "single_word": false
45
+ }
46
+ }
metadata_tokenizer/tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
metadata_tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
parsing_cm3p.py ADDED
@@ -0,0 +1,757 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from datetime import timedelta
3
+ from enum import Enum
4
+ from os import PathLike
5
+ from typing import Optional, Union, IO
6
+
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+ from slider import Beatmap, Circle, Slider, Spinner, HoldNote, TimingPoint
10
+ from slider.curve import Linear, Catmull, Perfect, MultiBezier
11
+ from transformers import FeatureExtractionMixin, AutoFeatureExtractor
12
+
13
+ from .configuration_cm3p import CM3PConfig
14
+
15
+
16
class EventType(Enum):
    """Token event types emitted by the beatmap parser.

    Covers hit objects of all four osu! game modes (standard, taiko, catch,
    mania) plus timing and metadata markers (beats, measures, timing points,
    kiai toggles and scroll-speed changes).
    """
    CIRCLE = "circle"
    SPINNER = "spinner"
    SPINNER_END = "spinner_end"
    SLIDER_HEAD = "slider_head"
    BEZIER_ANCHOR = "bezier_anchor"
    PERFECT_ANCHOR = "perfect_anchor"
    CATMULL_ANCHOR = "catmull_anchor"
    RED_ANCHOR = "red_anchor"
    LAST_ANCHOR = "last_anchor"
    SLIDER_END = "slider_end"
    REPEAT_END = "repeat_end"
    BEAT = "beat"
    MEASURE = "measure"
    TIMING_POINT = "timing_point"
    KIAI_ON = "kiai_on"
    KIAI_OFF = "kiai_off"
    HOLD_NOTE = "hold_note"
    HOLD_NOTE_END = "hold_note_end"
    SCROLL_SPEED_CHANGE = "scroll_speed_change"
    DRUMROLL = "drumroll"
    DRUMROLL_END = "drumroll_end"
    DENDEN = "denden"
    DENDEN_END = "denden_end"
40
+
41
+
42
# Event types whose groups may carry a new-combo flag (objects that can start
# a new combo).
EVENT_TYPES_WITH_NEW_COMBO = [
    EventType.CIRCLE,
    EventType.SLIDER_HEAD,
]
46
+
47
+
48
@dataclasses.dataclass
class Group:
    """One parsed event together with all of its associated token data."""
    event_type: EventType = None
    time: int = 0  # event time in milliseconds
    has_time: bool = False  # whether `time` is meaningful for this event
    snapping: int = None  # beat-snap divisor, when snapping is enabled
    distance: int = None  # distance from the previous object, when enabled
    x: int = None  # playfield x coordinate, when positions are enabled
    y: int = None  # playfield y coordinate, when positions are enabled
    mania_column: int = None  # osu!mania column index, when applicable
    new_combo: bool = False
    hitsounds: list[int] = dataclasses.field(default_factory=list)  # per-edge hitsound bitmasks
    samplesets: list[int] = dataclasses.field(default_factory=list)  # per-edge sample sets
    additions: list[int] = dataclasses.field(default_factory=list)  # per-edge addition sets
    volumes: list[int] = dataclasses.field(default_factory=list)  # per-edge volumes (0-100)
    scroll_speed: float = None  # SV / scroll-speed multiplier, when applicable
64
+
65
+
66
def merge_groups(groups1: list[Group], groups2: list[Group]) -> list[Group]:
    """Merge two lists of groups in a time sorted manner. Assumes both lists are sorted by time.

    Groups whose ``time`` is ``None`` inherit the comparison time of the
    previous group in their own list, so they stay adjacent to it.

    Args:
        groups1: List of groups, sorted by time.
        groups2: List of groups, sorted by time.

    Returns:
        merged_groups: Merged list of groups.
    """
    merged_groups = []
    i = 0
    j = 0
    t1 = -np.inf
    t2 = -np.inf

    while i < len(groups1) and j < len(groups2):
        # Explicit None check: a time of 0 ms is falsy but is a valid
        # timestamp and must not fall back to the previous time.
        t1 = groups1[i].time if groups1[i].time is not None else t1
        t2 = groups2[j].time if groups2[j].time is not None else t2

        if t1 <= t2:
            merged_groups.append(groups1[i])
            i += 1
        else:
            merged_groups.append(groups2[j])
            j += 1

    # Add remaining groups from both lists
    merged_groups.extend(groups1[i:])
    merged_groups.extend(groups2[j:])
    return merged_groups
97
+
98
+
99
def speed_groups(groups: list[Group], speed: float) -> list[Group]:
    """Scale all group times by a speed multiplier.

    Note: group times are rewritten in place; the returned list contains the
    same (mutated) group objects.

    Args:
        groups: List of groups.
        speed: Speed multiplier.

    Returns:
        sped_groups: List of groups with adjusted times.
    """
    adjusted = []
    for group in groups:
        group.time = int(group.time / speed)
        adjusted.append(group)

    return adjusted
115
+
116
+
117
def get_median_mpb_beatmap(beatmap: Beatmap) -> float:
    """Compute the duration-weighted median ms-per-beat of a beatmap.

    The reference end time is the start time of the last hit object (hold
    notes use their end time).
    """
    # Not include last slider's end time
    last_time = max(ho.end_time if isinstance(ho, HoldNote) else ho.time for ho in beatmap.hit_objects(stacking=False))
    # total_seconds() keeps millisecond precision; timedelta.seconds would
    # truncate to whole seconds and ignore any day component.
    last_time = int(last_time.total_seconds() * 1000)
    return get_median_mpb(beatmap.timing_points, last_time)
122
+
123
+
124
def get_median_mpb(timing_points: list[TimingPoint], last_time: float) -> float:
    """Compute the ms-per-beat value active for the longest total duration.

    Walks the timing points backwards, attributing the span between each
    uninherited point and the previous end time to that point's beat length.

    Args:
        timing_points: All timing points of the beatmap, in file order.
        last_time: End of the considered range, in milliseconds.

    Returns:
        The beat length (ms per beat) covering the largest share of the map,
        or 0 if no uninherited timing point applies.
    """
    # This is identical to osu! stable implementation
    this_beat_length = 0

    bpm_durations = {}

    for i in range(len(timing_points) - 1, -1, -1):
        tp = timing_points[i]
        # total_seconds() preserves millisecond precision (timedelta.seconds
        # would truncate to whole seconds).
        offset = int(tp.offset.total_seconds() * 1000)

        if tp.parent is None:
            this_beat_length = tp.ms_per_beat

        # Skip inherited points (except a leading one) and points past the end.
        if this_beat_length == 0 or offset > last_time or (tp.parent is not None and i > 0):
            continue

        start = 0 if i == 0 else offset
        bpm_durations[this_beat_length] = bpm_durations.get(this_beat_length, 0) + int(last_time - start)

        last_time = offset

    longest_time = 0
    median = 0

    for bpm, duration in bpm_durations.items():
        if duration > longest_time:
            longest_time = duration
            median = bpm

    return median
156
+
157
+
158
def load_beatmap(beatmap: Union[str, PathLike, IO[str], Beatmap]) -> Beatmap:
    """Load a beatmap from a file path, file object, or Beatmap object.

    Args:
        beatmap: Beatmap file path, file object, or Beatmap object.

    Returns:
        beatmap: Loaded Beatmap object.
    """
    if isinstance(beatmap, (str, PathLike)):
        beatmap = Beatmap.from_path(beatmap)
    elif not isinstance(beatmap, Beatmap):
        # typing.IO is not reliable with isinstance() (real file objects are
        # not subclasses of it), so treat any remaining non-Beatmap value as
        # an open file object and hand it to slider directly.
        beatmap = Beatmap.from_file(beatmap)
    return beatmap
172
+
173
+
174
def get_song_length(
    samples: np.ndarray = None,
    sample_rate: int = None,
    beatmap: Optional[Union[Beatmap, list[TimingPoint]]] = None,
) -> float:
    """Determine the song length in seconds.

    Prefers the audio data when given; otherwise falls back to the last hit
    object, then the last timing point, of the beatmap.

    Args:
        samples: Audio samples (optional).
        sample_rate: Audio sample rate in Hz (optional).
        beatmap: Beatmap object or raw list of timing points (optional).

    Returns:
        Song length in seconds, or 0 when it cannot be determined.
    """
    if samples is not None and sample_rate is not None:
        return len(samples) / sample_rate

    if beatmap is None:
        return 0

    if isinstance(beatmap, Beatmap):
        hit_objects = beatmap.hit_objects(stacking=False)
        if len(hit_objects) > 0:
            last_ho = hit_objects[-1]
            last_time = last_ho.end_time if hasattr(last_ho, "end_time") else last_ho.time
            return last_time.total_seconds() + 0.000999  # Add a small buffer to the last time

    timing = beatmap.timing_points if isinstance(beatmap, Beatmap) else beatmap
    if len(timing) == 0:
        return 0

    return timing[-1].offset.total_seconds() + 0.01
195
+
196
+
197
class CM3PBeatmapParser(FeatureExtractionMixin):
    """
    A class to parse CM3P beatmap files.
    """
    def __init__(
        self,
        add_timing: bool = True,
        add_snapping: bool = True,
        add_timing_points: bool = True,
        add_hitsounds: bool = True,
        add_distances: bool = True,
        add_positions: bool = True,
        add_kiai: bool = True,
        add_sv: bool = True,
        add_mania_sv: bool = True,
        mania_bpm_normalized_scroll_speed: bool = True,
        slider_version: int = 2,
        **kwargs,
    ):
        # Feature toggles controlling which event kinds the parser emits.
        self.add_timing = add_timing  # beat/measure/timing-point events
        self.add_snapping = add_snapping  # beat-snap divisors on timed events
        self.add_timing_points = add_timing_points  # TIMING_POINT events at BPM changes
        self.add_hitsounds = add_hitsounds  # hitsound/sampleset/volume data
        self.add_distances = add_distances  # distance from the previous object
        self.add_positions = add_positions  # absolute x/y coordinates
        self.add_kiai = add_kiai  # kiai on/off events
        self.add_sv = add_sv  # slider velocity on slider heads
        self.add_mania_sv = add_mania_sv  # scroll-speed events in mania
        self.mania_bpm_normalized_scroll_speed = mania_bpm_normalized_scroll_speed
        # 1: anchors carry interpolated times; 2: timeless anchors + LAST_ANCHOR
        self.slider_version = slider_version
        super().__init__(**kwargs)
228
+
229
    def parse_beatmap(
        self,
        beatmap: Union[str, PathLike, IO[str], Beatmap],
        speed: float = 1.0,
        song_length: Optional[float] = None
    ) -> list[Group]:
        """Parse an .osu beatmap.

        Each hit object is parsed into a list of Event objects, in order of its
        appearance in the beatmap. In other words, in ascending order of time.

        Args:
            beatmap: Beatmap file path, file object, or Beatmap object.
            speed: Speed multiplier for the beatmap.
            song_length: Length of the song in seconds. If not provided, it will be calculated from the beatmap.

        Returns:
            result: Time-sorted list of Group objects, including (depending on
            configuration) scroll-speed, kiai and timing groups.
        """
        beatmap = load_beatmap(beatmap)
        hit_objects = beatmap.hit_objects(stacking=False)
        last_pos = np.array((256, 192))  # playfield centre; initial distance reference
        groups = []

        for hit_object in hit_objects:
            if isinstance(hit_object, Circle):
                last_pos = self._parse_circle(hit_object, groups, last_pos, beatmap)
            elif isinstance(hit_object, Slider):
                # mode 1 (taiko) represents sliders as drumrolls
                if beatmap.mode == 1:
                    self._parse_drumroll(hit_object, groups, beatmap)
                else:
                    last_pos = self._parse_slider(hit_object, groups, last_pos, beatmap)
            elif isinstance(hit_object, Spinner):
                # mode 1 (taiko) represents spinners as dendens
                if beatmap.mode == 1:
                    self._parse_denden(hit_object, groups, beatmap)
                else:
                    last_pos = self._parse_spinner(hit_object, groups, beatmap)
            elif isinstance(hit_object, HoldNote):
                last_pos = self._parse_hold_note(hit_object, groups, beatmap)

        # Sort groups by time
        if len(groups) > 0:
            groups = sorted(groups, key=lambda x: x.time)
        result = list(groups)

        if self.add_mania_sv and beatmap.mode == 3:
            scroll_speed_events = self.parse_scroll_speeds(beatmap)
            result = merge_groups(scroll_speed_events, result)

        if self.add_kiai:
            kiai_events = self.parse_kiai(beatmap)
            result = merge_groups(kiai_events, result)

        if self.add_timing:
            timing_events = self.parse_timing(beatmap, song_length=song_length)
            result = merge_groups(timing_events, result)

        if speed != 1.0:
            result = speed_groups(result, speed)

        return result
291
+
292
    def parse_scroll_speeds(self, beatmap: Beatmap, speed: float = 1.0) -> list[Group]:
        """Extract all BPM-normalized scroll speed changes from a beatmap.

        Args:
            beatmap: Beatmap object.
            speed: Speed multiplier applied to event times.

        Returns:
            groups: SCROLL_SPEED_CHANGE groups, one per effective change.
        """
        normalized = self.mania_bpm_normalized_scroll_speed
        groups = []
        median_mpb = get_median_mpb_beatmap(beatmap)
        mpb = median_mpb
        last_normalized_scroll_speed = -1  # sentinel: no speed emitted yet

        for i, tp in enumerate(beatmap.timing_points):
            if tp.parent is None:
                # Uninherited point: updates the BPM and resets SV to 1x.
                mpb = tp.ms_per_beat
                scroll_speed = 1
            else:
                # Inherited point: negative ms_per_beat encodes the SV multiplier.
                scroll_speed = -100 / tp.ms_per_beat

            # Only emit for the last timing point at a given offset.
            if i == len(beatmap.timing_points) - 1 or beatmap.timing_points[i + 1].offset > tp.offset:
                normalized_scroll_speed = scroll_speed * median_mpb / mpb if normalized else scroll_speed

                if normalized_scroll_speed != last_normalized_scroll_speed or last_normalized_scroll_speed == -1:
                    self._add_group(
                        EventType.SCROLL_SPEED_CHANGE,
                        groups,
                        time=tp.offset,
                        beatmap=beatmap,
                        scroll_speed=normalized_scroll_speed,
                    )
                    last_normalized_scroll_speed = normalized_scroll_speed

        if speed != 1.0:
            groups = speed_groups(groups, speed)

        return groups
324
+
325
+ def parse_kiai(self, beatmap: Beatmap, speed: float = 1.0) -> list[Group]:
326
+ """Extract all kiai information from a beatmap."""
327
+ groups = []
328
+ kiai = False
329
+
330
+ for tp in beatmap.timing_points:
331
+ if tp.kiai_mode == kiai:
332
+ continue
333
+
334
+ self._add_group(
335
+ EventType.KIAI_ON if tp.kiai_mode else EventType.KIAI_OFF,
336
+ groups,
337
+ time=tp.offset,
338
+ beatmap=beatmap,
339
+ )
340
+ kiai = tp.kiai_mode
341
+
342
+ if speed != 1.0:
343
+ groups = speed_groups(groups, speed)
344
+
345
+ return groups
346
+
347
    def parse_timing(self, beatmap: Beatmap | list[TimingPoint], speed: float = 1.0, song_length: Optional[float] = None) -> list[Group]:
        """Extract all timing information from a beatmap.

        Emits TIMING_POINT / MEASURE / BEAT groups on every beat between each
        BPM change and the next one (or the end of the song).

        Args:
            beatmap: Beatmap object or a raw list of timing points.
            speed: Speed multiplier applied to event times.
            song_length: Song length in seconds; derived from the beatmap when omitted.

        Returns:
            groups: List of timing groups.

        Raises:
            AssertionError: If there are no timing points.
        """
        timing = beatmap.timing_points if isinstance(beatmap, Beatmap) else beatmap
        assert len(timing) > 0, "No timing points found in beatmap."

        groups = []
        last_time = song_length or get_song_length(beatmap=beatmap)
        last_time = int(last_time * 1000)  # milliseconds

        # Get all timing points with BPM changes
        timing_points = [tp for tp in timing if tp.bpm]

        for i, tp in enumerate(timing_points):
            # Generate beat and measure events until the next timing point
            # (stop 10 ms early so the next TIMING_POINT is not duplicated by a beat).
            next_tp = timing_points[i + 1] if i + 1 < len(timing_points) else None
            next_time = next_tp.offset.total_seconds() * 1000 - 10 if next_tp else last_time
            start_time = tp.offset.total_seconds() * 1000
            time = start_time
            measure_counter = 0
            beat_delta = tp.ms_per_beat
            while time <= next_time:
                if self.add_timing_points and measure_counter == 0:
                    event_type = EventType.TIMING_POINT
                elif measure_counter % tp.meter == 0:
                    event_type = EventType.MEASURE
                else:
                    event_type = EventType.BEAT

                self._add_group(
                    event_type,
                    groups,
                    time=timedelta(milliseconds=time),
                    add_snap=False,
                )

                # Exit early if the beat_delta is too small to avoid infinite loops
                if beat_delta <= 10:
                    break

                measure_counter += 1
                time = start_time + measure_counter * beat_delta

        if speed != 1.0:
            groups = speed_groups(groups, speed)

        return groups
393
+
394
+ @staticmethod
395
+ def uninherited_point_at(time: timedelta, beatmap: Beatmap):
396
+ tp = beatmap.timing_point_at(time)
397
+ return tp if tp.parent is None else tp.parent
398
+
399
+ @staticmethod
400
+ def hitsound_point_at(time: timedelta, beatmap: Beatmap):
401
+ hs_query = time + timedelta(milliseconds=5)
402
+ return beatmap.timing_point_at(hs_query)
403
+
404
+ def scroll_speed_at(self, time: timedelta, beatmap: Beatmap) -> float:
405
+ query = time
406
+ tp = beatmap.timing_point_at(query)
407
+ return self.tp_to_scroll_speed(tp)
408
+
409
+ def tp_to_scroll_speed(self, tp: TimingPoint) -> float:
410
+ if tp.parent is None or tp.ms_per_beat >= 0 or np.isnan(tp.ms_per_beat):
411
+ return 1
412
+ else:
413
+ return np.clip(-100 / tp.ms_per_beat, 0.01, 10)
414
+
415
    def _get_snapping(self, time: timedelta, beatmap: Beatmap, add_snap: bool = True) -> int:
        """Compute the beat-snap divisor for a time.

        Args:
            time: Time to compute the snapping for.
            beatmap: Beatmap object.
            add_snap: Whether snapping should be computed at all.

        Returns:
            The smallest divisor (1-16) whose grid lies within 2 ms of ``time``,
            0 if the time is not snapped, or None when snapping is disabled.
        """
        if not add_snap or not self.add_snapping:
            return None

        tp = self.uninherited_point_at(time, beatmap)
        # Position of `time` in beats relative to the red line's offset.
        beats = (time - tp.offset).total_seconds() * 1000 / tp.ms_per_beat
        snapping = 0
        for i in range(1, 17):
            # If the difference between the time and the snapped time is less than 2 ms, that is the correct snapping
            if abs(beats - round(beats * i) / i) * tp.ms_per_beat < 2:
                snapping = i
                break

        return snapping
436
+
437
    def _get_hitsounds(self, time: timedelta, hitsound: int, addition: str, beatmap: Beatmap) -> tuple[int, int, int, int]:
        """Resolve effective hitsound data for one hit-object edge.

        Args:
            time: Reference time of the edge.
            hitsound: Raw hitsound bitmask from the .osu file.
            addition: Raw addition string ("sampleset:additionset[:index:volume:file]").
            beatmap: Beatmap object.

        Returns:
            (hitsound, sample_set, addition_set, volume) with values inherited
            from the governing timing point where the raw values are 0.
        """
        tp = self.hitsound_point_at(time, beatmap)
        tp_sample_set = tp.sample_type if tp.sample_type != 0 else 2  # Inherit to soft sample set
        addition_split = addition.split(":")
        sample_set = int(addition_split[0]) if addition_split[0] != "0" else tp_sample_set
        addition_set = int(addition_split[1]) if addition_split[1] != "0" else sample_set
        volume = int(addition_split[3]) if len(addition_split) > 3 and addition_split[3] != "0" else tp.volume

        sample_set = sample_set if 0 < sample_set < 4 else 1  # Overflow default to normal sample set
        addition_set = addition_set if 0 < addition_set < 4 else 1  # Overflow default to normal sample set
        hitsound = hitsound & 14  # Only take the bits for whistle, finish, and clap
        volume = np.clip(volume, 0, 100)

        return hitsound, sample_set, addition_set, volume
451
+
452
+ def _get_position(self, pos: npt.NDArray, last_pos: npt.NDArray) -> tuple[int, int, int, npt.NDArray]:
453
+ x, y, dist = None, None, None
454
+
455
+ if self.add_distances:
456
+ dist = int(np.linalg.norm(pos - last_pos))
457
+
458
+ if self.add_positions:
459
+ x = int(pos[0])
460
+ y = int(pos[1])
461
+
462
+ return x, y, dist, pos
463
+
464
+ def _get_mania_column(self, pos: npt.NDArray, columns: int) -> int:
465
+ column = int(np.clip(pos[0] / 512 * columns, 0, columns - 1))
466
+ return column
467
+
468
+ def _add_group(
469
+ self,
470
+ event_type: EventType,
471
+ groups: list[Group],
472
+ time: timedelta,
473
+ *,
474
+ beatmap: Beatmap = None,
475
+ add_snap: bool = True,
476
+ has_time: bool = True,
477
+ pos: npt.NDArray = None,
478
+ last_pos: npt.NDArray = None,
479
+ new_combo: bool = False,
480
+ hitsound_ref_times: list[timedelta] = None,
481
+ hitsounds: list[int] = None,
482
+ additions: list[str] = None,
483
+ scroll_speed: Optional[float] = None,
484
+ ) -> npt.NDArray:
485
+ """Add a group of events to the event list."""
486
+ group = Group(
487
+ event_type=event_type,
488
+ time=int(time.total_seconds() * 1000 + 1e-5)
489
+ )
490
+
491
+ if has_time:
492
+ group.has_time = True
493
+ group.snapping = self._get_snapping(time, beatmap, add_snap)
494
+ if pos is not None:
495
+ if beatmap.mode in [0, 2]:
496
+ x, y, dist, last_pos = self._get_position(pos, last_pos)
497
+ group.x = x
498
+ group.y = y
499
+ group.distance = dist
500
+ elif beatmap.mode == 3:
501
+ group.column = self._get_mania_column(pos, int(beatmap.circle_size))
502
+ if new_combo and beatmap.mode in [0, 2]:
503
+ group.new_combo = True
504
+ if scroll_speed is not None:
505
+ group.scroll_speed = scroll_speed
506
+ if hitsound_ref_times is not None and self.add_hitsounds:
507
+ for i, ref_time in enumerate(hitsound_ref_times):
508
+ hitsound, sample_set, addition_set, volume = self._get_hitsounds(ref_time, hitsounds[i], additions[i], beatmap)
509
+ group.hitsounds.append(hitsound)
510
+ group.samplesets.append(sample_set)
511
+ group.additions.append(addition_set)
512
+ group.volumes.append(volume)
513
+
514
+ groups.append(group)
515
+
516
+ return last_pos
517
+
518
    def _parse_circle(self, circle: Circle, groups: list[Group], last_pos: npt.NDArray, beatmap: Beatmap) -> npt.NDArray:
        """Parse a circle hit object.

        Args:
            circle: Circle object.
            groups: List of groups to add to.
            last_pos: Last position of the hit objects.
            beatmap: Beatmap the circle belongs to.

        Returns:
            pos: Position of the circle.
        """
        return self._add_group(
            EventType.CIRCLE,
            groups,
            time=circle.time,
            beatmap=beatmap,
            pos=np.array(circle.position),
            last_pos=last_pos,
            new_combo=circle.new_combo,
            hitsound_ref_times=[circle.time],
            hitsounds=[circle.hitsound],
            additions=[circle.addition],
            # Taiko (mode 1) notes carry the current scroll speed.
            scroll_speed=self.scroll_speed_at(circle.time, beatmap) if beatmap.mode == 1 else None,
        )
542
+
543
    def _parse_slider(self, slider: Slider, groups: list[Group], last_pos: npt.NDArray, beatmap: Beatmap) -> npt.NDArray:
        """Parse a slider hit object.

        Emits SLIDER_HEAD, one anchor group per interior control point,
        (for slider_version 2) a LAST_ANCHOR, a SLIDER_END carrying body and
        repeat hitsounds, and a REPEAT_END at the slider's end time.

        Args:
            slider: Slider object.
            groups: List of groups to add to.
            last_pos: Last position of the hit objects.
            beatmap: Beatmap the slider belongs to.

        Returns:
            pos: Last position of the slider.
        """
        # Ignore sliders which are too big
        if len(slider.curve.points) >= 100:
            return last_pos

        last_pos = self._add_group(
            EventType.SLIDER_HEAD,
            groups,
            time=slider.time,
            beatmap=beatmap,
            pos=np.array(slider.position),
            last_pos=last_pos,
            new_combo=slider.new_combo,
            hitsound_ref_times=[slider.time],
            hitsounds=[slider.edge_sounds[0] if len(slider.edge_sounds) > 0 else 0],
            additions=[slider.edge_additions[0] if len(slider.edge_additions) > 0 else '0:0'],
            scroll_speed=self.scroll_speed_at(slider.time, beatmap) if self.add_sv else None,
        )

        # Duration of a single span (head to tail, not counting repeats).
        duration: timedelta = (slider.end_time - slider.time) / slider.repeat
        control_point_count = len(slider.curve.points)

        def append_control_points(event_type: EventType, last_pos: npt.NDArray = last_pos) -> npt.NDArray:
            # Anchors exclude the head (index 0) and tail (index -1).
            for i in range(1, control_point_count - 1):
                last_pos = add_anchor(event_type, i, last_pos)

            return last_pos

        def add_anchor(event_type: EventType, i: int, last_pos: npt.NDArray) -> npt.NDArray:
            # slider_version 1 spreads anchor times across the span;
            # version 2 keeps anchors timeless at the head time.
            return self._add_group(
                event_type,
                groups,
                time=slider.time + i / (control_point_count - 1) * duration if self.slider_version == 1 else slider.time,
                beatmap=beatmap,
                has_time=False,
                pos=np.array(slider.curve.points[i]),
                last_pos=last_pos,
            )

        if isinstance(slider.curve, Linear):
            last_pos = append_control_points(EventType.RED_ANCHOR, last_pos)
        elif isinstance(slider.curve, Catmull):
            last_pos = append_control_points(EventType.CATMULL_ANCHOR, last_pos)
        elif isinstance(slider.curve, Perfect):
            last_pos = append_control_points(EventType.PERFECT_ANCHOR, last_pos)
        elif isinstance(slider.curve, MultiBezier):
            for i in range(1, control_point_count - 1):
                # A duplicated control point marks a red (sharp) anchor.
                if slider.curve.points[i] == slider.curve.points[i + 1]:
                    last_pos = add_anchor(EventType.RED_ANCHOR, i, last_pos)
                elif slider.curve.points[i] != slider.curve.points[i - 1]:
                    last_pos = add_anchor(EventType.BEZIER_ANCHOR, i, last_pos)

        if self.slider_version == 2:
            # Add last control point without time
            last_pos = self._add_group(
                EventType.LAST_ANCHOR,
                groups,
                time=slider.time,
                beatmap=beatmap,
                has_time=False,
                pos=np.array(slider.curve.points[-1]),
                last_pos=last_pos,
            )

        # Add body hitsounds and remaining edge hitsounds
        last_pos = self._add_group(
            EventType.SLIDER_END,
            groups,
            time=slider.time + duration,
            beatmap=beatmap,
            pos=np.array(slider.curve.points[-1]) if self.slider_version == 1 else None,
            last_pos=last_pos,
            hitsound_ref_times=[slider.time + timedelta(milliseconds=1)] + [slider.time + i * duration for i in range(1, slider.repeat)],
            hitsounds=[slider.hitsound] + [slider.edge_sounds[i] if len(slider.edge_sounds) > i else 0 for i in range(1, slider.repeat)],
            additions=[slider.addition] + [slider.edge_additions[i] if len(slider.edge_additions) > i else '0:0' for i in range(1, slider.repeat)],
        )

        return self._add_group(
            EventType.REPEAT_END,
            groups,
            time=slider.end_time,
            beatmap=beatmap,
            pos=np.array(slider.curve(1)),
            last_pos=last_pos,
            hitsound_ref_times=[slider.end_time],
            hitsounds=[slider.edge_sounds[-1] if len(slider.edge_sounds) > 0 else 0],
            additions=[slider.edge_additions[-1] if len(slider.edge_additions) > 0 else '0:0'],
        )
641
+
642
    def _parse_spinner(self, spinner: Spinner, groups: list[Group], beatmap: Beatmap) -> npt.NDArray:
        """Parse a spinner hit object.

        Args:
            spinner: Spinner object.
            groups: List of groups to add to.
            beatmap: Beatmap the spinner belongs to.

        Returns:
            pos: Playfield centre (256, 192), the reference position after a spinner.
        """
        self._add_group(
            EventType.SPINNER,
            groups,
            time=spinner.time,
            beatmap=beatmap,
        )

        # Hitsounds play on the spinner's end, not its start.
        self._add_group(
            EventType.SPINNER_END,
            groups,
            time=spinner.end_time,
            beatmap=beatmap,
            hitsound_ref_times=[spinner.end_time],
            hitsounds=[spinner.hitsound],
            additions=[spinner.addition],
        )

        return np.array((256, 192))
670
+
671
    def _parse_hold_note(self, hold_note: HoldNote, groups: list[Group], beatmap: Beatmap) -> npt.NDArray:
        """Parse a hold note hit object (osu!mania).

        Args:
            hold_note: Hold note object.
            groups: List of groups to add to.
            beatmap: Beatmap the hold note belongs to.

        Returns:
            pos: Position of the hold note.
        """
        pos = np.array(hold_note.position)

        self._add_group(
            EventType.HOLD_NOTE,
            groups,
            time=hold_note.time,
            beatmap=beatmap,
            pos=pos,
            hitsound_ref_times=[hold_note.time],
            hitsounds=[hold_note.hitsound],
            additions=[hold_note.addition],
        )

        # The end event repeats the position so the column can be derived.
        self._add_group(
            EventType.HOLD_NOTE_END,
            groups,
            time=hold_note.end_time,
            beatmap=beatmap,
            pos=pos,
        )

        return pos
703
+
704
    def _parse_drumroll(self, slider: Slider, groups: list[Group], beatmap: Beatmap):
        """Parse a drumroll hit object (osu!taiko slider).

        Args:
            slider: Slider object.
            groups: List of groups to add to.
            beatmap: Beatmap the drumroll belongs to.
        """
        self._add_group(
            EventType.DRUMROLL,
            groups,
            time=slider.time,
            beatmap=beatmap,
            hitsound_ref_times=[slider.time],
            hitsounds=[slider.hitsound],  # Edge hitsounds are not supported in drumrolls
            additions=[slider.addition],
            scroll_speed=self.scroll_speed_at(slider.time, beatmap),
        )

        self._add_group(
            EventType.DRUMROLL_END,
            groups,
            time=slider.end_time,
            beatmap=beatmap,
        )
728
+
729
    def _parse_denden(self, spinner: Spinner, groups: list[Group], beatmap: Beatmap):
        """Parse a denden hit object (osu!taiko spinner).

        Args:
            spinner: Spinner object.
            groups: List of groups to add to.
            beatmap: Beatmap the denden belongs to.
        """
        self._add_group(
            EventType.DENDEN,
            groups,
            time=spinner.time,
            beatmap=beatmap,
            hitsound_ref_times=[spinner.time],
            hitsounds=[spinner.hitsound],
            additions=[spinner.addition],
            scroll_speed=self.scroll_speed_at(spinner.time, beatmap),
        )

        self._add_group(
            EventType.DENDEN_END,
            groups,
            time=spinner.end_time,
            beatmap=beatmap,
        )
753
+
754
+
755
# Register the parser so AutoFeatureExtractor resolves CM3P checkpoints to it.
AutoFeatureExtractor.register(CM3PConfig, CM3PBeatmapParser)

__all__ = ["CM3PBeatmapParser", "EventType", "Group", "load_beatmap", "get_song_length", "EVENT_TYPES_WITH_NEW_COMBO"]
processing_cm3p.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import itertools
3
+ import math
4
+ import os
5
+ from os import PathLike
6
+ from pathlib import Path
7
+ from typing import Optional, Union, IO, TypedDict
8
+
9
+ import numpy as np
10
+ from huggingface_hub.errors import HfHubHTTPError
11
+ from pandas import Series
12
+ from slider import Beatmap, HoldNote
13
+ from transformers import WhisperFeatureExtractor, AutoProcessor, BatchEncoding
14
+ from transformers.dynamic_module_utils import custom_object_save
15
+ from transformers.tokenization_utils_base import TruncationStrategy, PreTrainedTokenizerBase
16
+ from transformers.utils import is_torch_available, PaddingStrategy, PROCESSOR_NAME, logging
17
+ from huggingface_hub import CommitOperationAdd, create_branch, create_commit
18
+
19
+ from .configuration_cm3p import CM3PConfig
20
+ from .parsing_cm3p import CM3PBeatmapParser, load_beatmap, get_song_length
21
+ from .tokenization_cm3p import CM3PBeatmapTokenizer, CM3PMetadataTokenizer, CM3PMetadata, merge_metadata_dicts
22
+
23
+ if is_torch_available():
24
+ import torch
25
+
26
+ from transformers.audio_utils import AudioInput, make_list_of_audio, load_audio
27
+ from transformers.feature_extraction_utils import BatchFeature
28
+ from transformers.processing_utils import AudioKwargs, ProcessorMixin, CommonKwargs
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ def get_hold_note_ratio(beatmap: Beatmap) -> Optional[float]:
34
+ notes = beatmap.hit_objects(stacking=False)
35
+
36
+ if len(notes) == 0:
37
+ return None
38
+
39
+ hold_note_count = 0
40
+ for note in notes:
41
+ if isinstance(note, HoldNote):
42
+ hold_note_count += 1
43
+ return hold_note_count / len(notes)
44
+
45
+
46
+ def get_scroll_speed_ratio(beatmap: Beatmap) -> Optional[float]:
47
+ # Number of scroll speed changes divided by number of distinct hit object times
48
+ notes = beatmap.hit_objects(stacking=False)
49
+
50
+ if len(notes) == 0:
51
+ return None
52
+
53
+ last_time = -1
54
+ num_note_times = 0
55
+ for note in notes:
56
+ if note.time != last_time:
57
+ num_note_times += 1
58
+ last_time = note.time
59
+ last_scroll_speed = -1
60
+ num_scroll_speed_changes = 0
61
+ for timing_point in beatmap.timing_points:
62
+ if timing_point.parent is None:
63
+ last_scroll_speed = 1
64
+ else:
65
+ scroll_speed = -100 / timing_point.ms_per_beat
66
+ if scroll_speed != last_scroll_speed and last_scroll_speed != -1:
67
+ num_scroll_speed_changes += 1
68
+ last_scroll_speed = scroll_speed
69
+ return num_scroll_speed_changes / num_note_times
70
+
71
+
72
+ def get_hitsounded_status(beatmap: Beatmap) -> bool:
73
+ notes = beatmap.hit_objects(stacking=False)
74
+ for note in notes:
75
+ if note.hitsound != 0:
76
+ return True
77
+ return False
78
+
79
+
80
+ def get_difficulty(beatmap_metadata: Series, speed: float = 1.0) -> float:
81
+ # StarRating is an array that gives the difficulty for the speeds:
82
+ # 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0
83
+ # Linearly interpolate between the two closest speeds
84
+ star_ratings = beatmap_metadata["StarRating"]
85
+ speed_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
86
+ return np.interp(speed, speed_ratios, star_ratings)
87
+
88
+
89
+ def get_metadata(
90
+ beatmap_metadata: Series = None,
91
+ beatmap: Beatmap = None,
92
+ audio_samples: np.ndarray = None,
93
+ sampling_rate: int = None,
94
+ speed: float = 1.0,
95
+ song_position: Optional[float] = None,
96
+ ) -> CM3PMetadata:
97
+ mode = beatmap.mode if beatmap is not None else beatmap_metadata["ModeInt"] if beatmap_metadata is not None else None
98
+ circle_size = beatmap.circle_size if beatmap is not None else beatmap_metadata["Cs"] if beatmap_metadata is not None else None
99
+ song_length = get_song_length(audio_samples, sampling_rate, beatmap)
100
+ return CM3PMetadata(
101
+ difficulty=get_difficulty(beatmap_metadata, speed) if beatmap_metadata is not None else None,
102
+ year=beatmap_metadata["SubmittedDate"].year if beatmap_metadata is not None else None,
103
+ mode=mode,
104
+ status=beatmap_metadata["Status"] if beatmap_metadata is not None else None,
105
+ mapper=beatmap_metadata["UserId"] if beatmap_metadata is not None else None,
106
+ cs=circle_size if mode in [0, 2] is not None else None,
107
+ hitsounded=get_hitsounded_status(beatmap) if beatmap is not None else None,
108
+ song_length=song_length,
109
+ song_position=song_position,
110
+ global_sv=beatmap.slider_multiplier if mode in [0, 2] and beatmap is not None else None,
111
+ mania_keycount=int(circle_size) if mode == 3 and beatmap is not None else None,
112
+ hold_note_ratio=get_hold_note_ratio(beatmap) if mode == 3 and beatmap is not None else None,
113
+ scroll_speed_ratio=get_scroll_speed_ratio(beatmap) if mode in [1, 3] and beatmap is not None else None,
114
+ tags=beatmap_metadata["TopTagIds"].tolist() if beatmap_metadata is not None else None,
115
+ )
116
+
117
+
118
+ class CM3PTokenizerKwargs(TypedDict, total=False):
119
+ add_special_tokens: Optional[bool]
120
+ padding: Union[bool, str, PaddingStrategy]
121
+ truncation: Union[bool, str, TruncationStrategy]
122
+ max_length: Optional[int]
123
+ pad_to_multiple_of: Optional[int]
124
+ return_token_type_ids: Optional[bool]
125
+ return_attention_mask: Optional[bool]
126
+ return_overflowing_tokens: Optional[bool]
127
+ return_special_tokens_mask: Optional[bool]
128
+ return_offsets_mapping: Optional[bool]
129
+ return_length: Optional[bool]
130
+ verbose: Optional[bool]
131
+ padding_side: Optional[str]
132
+ return_mm_token_type_ids: Optional[bool]
133
+
134
+
135
+ class CM3PBeatmapKwargs(CM3PTokenizerKwargs, total=False):
136
+ window_length_sec: float
137
+ window_stride_sec: float
138
+
139
+
140
+ class CM3PAudioKwargs(AudioKwargs, total=False):
141
+ max_source_positions: Optional[int]
142
+ hop_length: Optional[int]
143
+ window_size: Optional[int]
144
+ audio_length_per_tok: Optional[int]
145
+ device: Optional[str]
146
+
147
+
148
+ # noinspection PyTypedDict
149
+ class CM3PProcessorKwargs(CommonKwargs, CM3PBeatmapKwargs, CM3PTokenizerKwargs, CM3PAudioKwargs, total=False):
150
+ _defaults = {
151
+ "beatmap_kwargs": {
152
+ "max_length": 8000,
153
+ "padding": PaddingStrategy.LONGEST,
154
+ "truncation": TruncationStrategy.LONGEST_FIRST,
155
+ "window_length_sec": 30.0,
156
+ "window_stride_sec": 30.0,
157
+ "min_window_length_sec": 1.0,
158
+ },
159
+ "metadata_kwargs": {
160
+ "max_length": 128,
161
+ "padding": PaddingStrategy.LONGEST,
162
+ "truncation": TruncationStrategy.LONGEST_FIRST,
163
+ },
164
+ "audio_kwargs": {
165
+ "sampling_rate": 16000,
166
+ "padding": True,
167
+ "truncation": False,
168
+ "pad_to_multiple_of": 480000,
169
+ "max_source_positions": 3000,
170
+ "hop_length": 160,
171
+ "window_size": 400,
172
+ "audio_length_per_tok": 8,
173
+ "device": "cpu",
174
+ },
175
+ "common_kwargs": {
176
+ "return_tensors": "pt",
177
+ },
178
+ }
179
+
180
+ common_kwargs: CommonKwargs = {
181
+ **CommonKwargs.__annotations__,
182
+ }
183
+ beatmap_kwargs: CM3PBeatmapKwargs = {
184
+ **CM3PTokenizerKwargs.__annotations__,
185
+ }
186
+ metadata_kwargs: CM3PTokenizerKwargs = {
187
+ **CM3PTokenizerKwargs.__annotations__,
188
+ }
189
+ audio_kwargs: CM3PAudioKwargs = {
190
+ **CM3PAudioKwargs.__annotations__,
191
+ }
192
+
193
+
194
+ class CM3PProcessor(ProcessorMixin):
195
+ r"""
196
+ Constructs a CM3P processor which wraps [`WhisperFeatureExtractor`] and
197
+ [`MistralCommonTokenizer`] into a single processor that inherits both the audio feature extraction and
198
+ tokenizer functionalities.
199
+
200
+ Args:
201
+ audio_feature_extractor ([`WhisperFeatureExtractor`]):
202
+ The feature extractor is a required input.
203
+ beatmap_parser ([`CM3PBeatmapParser`]):
204
+ The beatmap parser is a required input.
205
+ beatmap_tokenizer ([`CM3PBeatmapTokenizer`]):
206
+ The beatmap tokenizer is a required input.
207
+ metadata_tokenizer ([`CM3PMetadataTokenizer`]):
208
+ The metadata tokenizer is a required input.
209
+ default_kwargs (`CM3PProcessorKwargs`, *optional*):
210
+ Default keyword arguments for the processor. If not provided, the processor will use its own defaults
211
+ """
212
+
213
+ attributes = ["audio_feature_extractor", "beatmap_parser", "beatmap_tokenizer", "metadata_tokenizer"]
214
+ audio_feature_extractor_class = "WhisperFeatureExtractor"
215
+ beatmap_parser_class = "CM3PBeatmapParser"
216
+ beatmap_tokenizer_class = "CM3PBeatmapTokenizer"
217
+ metadata_tokenizer_class = "CM3PMetadataTokenizer"
218
+
219
+ def __init__(
220
+ self,
221
+ audio_feature_extractor: WhisperFeatureExtractor,
222
+ beatmap_parser: CM3PBeatmapParser,
223
+ beatmap_tokenizer: CM3PBeatmapTokenizer,
224
+ metadata_tokenizer: CM3PMetadataTokenizer,
225
+ default_kwargs: Optional[CM3PProcessorKwargs] = None,
226
+ ):
227
+ self.audio_feature_extractor = audio_feature_extractor
228
+ self.beatmap_parser = beatmap_parser
229
+ self.beatmap_tokenizer = beatmap_tokenizer
230
+ self.metadata_tokenizer = metadata_tokenizer
231
+ self.audio_token = beatmap_tokenizer.audio_token
232
+
233
+ # noinspection PyProtectedMember
234
+ self.default_kwargs = default_kwargs or copy.deepcopy(CM3PProcessorKwargs._defaults)
235
+
236
+ super().__init__(audio_feature_extractor, beatmap_parser, beatmap_tokenizer, metadata_tokenizer)
237
+
238
+ def _pad_audio(
239
+ self,
240
+ audio_array: np.ndarray,
241
+ window_size: int = 400,
242
+ pad_to_multiple_of: Optional[int] = 480000,
243
+ **_,
244
+ ) -> np.ndarray:
245
+ r"""Pad the audio array to the desired length.
246
+
247
+ Args:
248
+ audio_array: Audio data as a numpy array.
249
+ sampling_rate: Sampling rate of the audio.
250
+
251
+ Returns:
252
+ Padded audio array.
253
+ """
254
+ if pad_to_multiple_of:
255
+ next_multiple_of_chunk_frames = math.ceil(audio_array.shape[-1] / pad_to_multiple_of) * pad_to_multiple_of
256
+ audio_array = np.pad(audio_array, (0, next_multiple_of_chunk_frames - audio_array.shape[-1]))
257
+ elif audio_array.shape[-1] < window_size:
258
+ # minimum length for audios is at least one spectrogram frame
259
+ audio_array = np.pad(audio_array, (0, window_size - audio_array.shape[-1]))
260
+
261
+ return audio_array
262
+
263
+ def _encode_audio(
264
+ self,
265
+ audio: np.ndarray,
266
+ hop_length: int = 160,
267
+ audio_length_per_tok: int = 8,
268
+ **kwargs,
269
+ ) -> tuple[np.ndarray, int]:
270
+ audio = self._pad_audio(audio, **kwargs)
271
+ signal_length = audio.shape[0]
272
+
273
+ # for spectrogram-based models, the waveform is downsampled by the hop_length when computing the log-mel
274
+ if signal_length % hop_length != 0:
275
+ signal_length = math.ceil(signal_length / hop_length - 1)
276
+ else:
277
+ signal_length = signal_length // hop_length
278
+
279
+ num_audio_tokens = math.ceil(signal_length / audio_length_per_tok)
280
+
281
+ return audio, num_audio_tokens
282
+
283
+ def _retrieve_input_features(self, audio, max_source_positions, **kwargs) -> Union[torch.Tensor, np.ndarray]:
284
+ """
285
+ Handles specific logic of CM3P expected input features: audio arrays should be padded to next multiple of 480000 (duration is a multiple of 30s), see CM3PProcessorKwargs' default audio_kwargs.
286
+ Then mel input features are extracted and stacked along batch dimension, splitting into chunks of max_source_positions.
287
+ """
288
+ return_tensors = kwargs.get("return_tensors", "pt")
289
+ input_features_list = []
290
+ for audio_array in audio:
291
+ audio_inputs = self.audio_feature_extractor(audio_array, **kwargs)
292
+
293
+ # let's split into chunks of max_source_positions, and then stack them along batch dimension
294
+ input_features = audio_inputs["input_features"].reshape(
295
+ self.audio_feature_extractor.feature_size, -1, max_source_positions
296
+ )
297
+
298
+ input_features_list.append(input_features.swapaxes(0, 1))
299
+
300
+ if return_tensors == "pt":
301
+ return torch.cat(input_features_list)
302
+
303
+ return np.concatenate(input_features_list)
304
+
305
+ def _load_audio(
306
+ self,
307
+ sampling_rate: int,
308
+ audio: Union[str, list[str], Path, list[Path], AudioInput],
309
+ audio_sampling_rate: Optional[Union[int, list[int]]] = None,
310
+ speed: float = 1.0,
311
+ ) -> list[np.ndarray]:
312
+ """
313
+ Helper method to load audio from various formats and return a list of audio buffers.
314
+ """
315
+
316
+ # convert Path objects to str
317
+ if isinstance(audio, Path):
318
+ audio = str(audio)
319
+ if isinstance(audio, list) and all(isinstance(el, Path) for el in audio):
320
+ audio = [str(el) for el in audio]
321
+
322
+ # validate audio input
323
+ is_str = isinstance(audio, str)
324
+ is_list_of_str = isinstance(audio, list) and all(isinstance(el, str) for el in audio)
325
+ is_list_of_audio = not (is_str or is_list_of_str)
326
+
327
+ if is_list_of_audio:
328
+ if audio_sampling_rate is None:
329
+ # noinspection PyUnresolvedReferences
330
+ logger.warning_once(
331
+ f"You've provided audio without specifying the sampling rate. It will be assumed to be {sampling_rate}, which can result in silent errors."
332
+ )
333
+ audio_sampling_rate = sampling_rate
334
+
335
+ if is_str:
336
+ audio = [load_audio(audio, sampling_rate=int(sampling_rate // speed))]
337
+ audio_sampling_rate = sampling_rate
338
+ elif is_list_of_str:
339
+ audio = [load_audio(el, sampling_rate=int(sampling_rate // speed)) for el in audio]
340
+ audio_sampling_rate = sampling_rate
341
+
342
+ audio = make_list_of_audio(audio)
343
+
344
+ if isinstance(audio_sampling_rate, int):
345
+ audio_sampling_rate = [audio_sampling_rate] * len(audio)
346
+
347
+ audio_buffers = []
348
+ for array, s in zip(audio, audio_sampling_rate):
349
+ array = np.asarray(array)
350
+ # Convert to mono if needed
351
+ if array.ndim == 2:
352
+ array = array.mean(axis=1)
353
+ # Resample if the sampling rate is different from the expected one
354
+ if s != sampling_rate:
355
+ import soxr
356
+ array = soxr.resample(array, s, sampling_rate, quality="HQ")
357
+ audio_buffers.append(array)
358
+
359
+ return audio_buffers
360
+
361
+ # noinspection PyTypedDict
362
+ def _merge_kwargs(self, **kwargs) -> CM3PProcessorKwargs:
363
+ output_kwargs = CM3PProcessorKwargs()
364
+ nested_modalities = ["beatmap_kwargs", "metadata_kwargs", "audio_kwargs", "common_kwargs"]
365
+ possible_modality_keywords = {"beatmap", "metadata", "audio"}
366
+ used_keys = set()
367
+
368
+ # pass defaults to output dictionary
369
+ output_kwargs.update(copy.deepcopy(self.default_kwargs))
370
+
371
+ # update modality kwargs with passed kwargs
372
+ non_modality_kwargs = set(kwargs) - set(output_kwargs)
373
+ for modality, output_kwarg in output_kwargs.items():
374
+ for modality_key in CM3PProcessorKwargs.__annotations__[modality].__annotations__:
375
+ # check if we received a structured kwarg dict or not to handle it correctly
376
+ if modality in kwargs:
377
+ kwarg_value = kwargs[modality].pop(modality_key, "__empty__")
378
+ # check if this key was passed as a flat kwarg.
379
+ if kwarg_value != "__empty__" and modality_key in non_modality_kwargs:
380
+ raise ValueError(
381
+ f"Keyword argument {modality_key} was passed two times:\n"
382
+ f"in a dictionary for {modality} and as a **kwarg."
383
+ )
384
+ elif modality_key in kwargs:
385
+ # we get a modality_key instead of popping it because modality-specific processors
386
+ # can have overlapping kwargs
387
+ kwarg_value = kwargs.get(modality_key, "__empty__")
388
+ else:
389
+ kwarg_value = "__empty__"
390
+ if not isinstance(kwarg_value, str) or kwarg_value != "__empty__":
391
+ output_kwarg[modality_key] = kwarg_value
392
+ used_keys.add(modality_key)
393
+
394
+ # Determine if kwargs is a flat dictionary or contains nested dictionaries
395
+ if any(key in nested_modalities for key in kwargs):
396
+ # kwargs is dictionary-based, and some keys match modality names
397
+ for modality, subdict in kwargs.items():
398
+ if modality in nested_modalities:
399
+ for subkey, subvalue in subdict.items():
400
+ if subkey not in used_keys:
401
+ output_kwargs[modality][subkey] = subvalue
402
+ used_keys.add(subkey)
403
+ else:
404
+ # kwargs is a flat dictionary
405
+ for key, kwarg in kwargs.items():
406
+ if key not in used_keys:
407
+ if key in CM3PProcessorKwargs.__annotations__["common_kwargs"].__annotations__:
408
+ output_kwargs["common_kwargs"][key] = kwarg
409
+ elif key not in possible_modality_keywords:
410
+ # noinspection PyUnresolvedReferences
411
+ logger.warning_once(
412
+ f"Keyword argument `{key}` is not a valid argument for this processor and will be ignored."
413
+ )
414
+
415
+ # all modality-specific kwargs are updated with common kwargs
416
+ for kwarg in output_kwargs.values():
417
+ kwarg.update(output_kwargs["common_kwargs"])
418
+ return output_kwargs
419
+
420
+ def __call__(
421
+ self,
422
+ metadata: Optional[Union[CM3PMetadata, list[CM3PMetadata]]] = None,
423
+ beatmap: Optional[Union[str, list[str], PathLike, list[PathLike], IO[str], list[IO[str]], Beatmap, list[Beatmap]]] = None,
424
+ audio: Optional[Union[str, list[str], Path, list[Path], AudioInput]] = None,
425
+ audio_sampling_rate: Optional[Union[int, list[int]]] = None,
426
+ speed: float = 1.0,
427
+ multiply_metadata: bool = False,
428
+ populate_metadata: bool = False,
429
+ metadata_dropout_prob: float = 0.0,
430
+ metadata_variations: int = 1,
431
+ **kwargs,
432
+ ):
433
+ output_kwargs = self._merge_kwargs(**kwargs)
434
+
435
+ beatmap_kwargs: CM3PTokenizerKwargs = output_kwargs["beatmap_kwargs"]
436
+ metadata_kwargs: CM3PTokenizerKwargs = output_kwargs["metadata_kwargs"]
437
+ audio_kwargs: CM3PAudioKwargs = output_kwargs["audio_kwargs"]
438
+ common_kwargs: CommonKwargs = output_kwargs["common_kwargs"]
439
+
440
+ window_length_sec = beatmap_kwargs.pop("window_length_sec")
441
+ window_stride_sec = beatmap_kwargs.pop("window_stride_sec")
442
+ min_window_length_sec = beatmap_kwargs.pop("min_window_length_sec", 1.0)
443
+ max_length = beatmap_kwargs.get("max_length", 8000)
444
+ metadata_max_length = metadata_kwargs.get("max_length", 128)
445
+ sampling_rate = audio_kwargs["sampling_rate"]
446
+ max_source_positions = audio_kwargs.get("max_source_positions", 3000)
447
+ audio_kwargs["padding"] = False
448
+ return_tensors = common_kwargs["return_tensors"]
449
+
450
+ metadata_encoding, beatmap_encoding, num_audio_tokens, metadata_variation_classes = None, None, None, None
451
+
452
+ if return_tensors is not None and return_tensors != "pt":
453
+ raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'` or `return_tensors=None`.")
454
+
455
+ if metadata is None and beatmap is None:
456
+ raise ValueError("You have to specify either metadata or beatmap. Both cannot be none.")
457
+
458
+ if audio is not None:
459
+ audio = self._load_audio(
460
+ sampling_rate,
461
+ audio,
462
+ audio_sampling_rate=audio_sampling_rate,
463
+ )
464
+
465
+ if beatmap is not None:
466
+ if not isinstance(beatmap, list):
467
+ beatmap = [beatmap]
468
+
469
+ if audio is not None:
470
+ if len(beatmap) != len(audio):
471
+ raise ValueError(
472
+ f"The number of beatmaps ({len(beatmap)}) must match the number of audio ({len(audio)})"
473
+ )
474
+ else:
475
+ audio = [None] * len(beatmap)
476
+
477
+ if multiply_metadata or populate_metadata and metadata is not None:
478
+ matched_metadata = metadata
479
+ if not isinstance(matched_metadata, list):
480
+ matched_metadata = [matched_metadata]
481
+ if (multiply_metadata or populate_metadata) and len(matched_metadata) != len(beatmap):
482
+ raise ValueError(
483
+ f"The number of metadata entries ({len(matched_metadata)}) must match the number of beatmaps ({len(beatmap)})"
484
+ "` if multiply_metadata` or `populate_metadata` is set to True."
485
+ )
486
+ else:
487
+ matched_metadata = [CM3PMetadata()] * len(beatmap) if populate_metadata else [None] * len(beatmap)
488
+
489
+ new_metadata = []
490
+ batch_start_ms = []
491
+ batch_groups = []
492
+ batch_audio = []
493
+ batch_num_audio_tokens = []
494
+ for b, m, audio_array in zip(beatmap, matched_metadata, audio):
495
+ b: Beatmap = load_beatmap(b)
496
+ song_length = get_song_length(audio_array, sampling_rate, b)
497
+ beatmap_groups = self.beatmap_parser.parse_beatmap(b, speed=speed, song_length=song_length)
498
+
499
+ def add_metadata(song_position: Optional[float] = None):
500
+ if populate_metadata:
501
+ new_metadata.append(merge_metadata_dicts(m, get_metadata(
502
+ beatmap=b,
503
+ audio_samples=audio_array,
504
+ sampling_rate=sampling_rate,
505
+ speed=speed,
506
+ song_position=song_position,
507
+ )))
508
+ else:
509
+ new_metadata.append(m)
510
+
511
+ if not multiply_metadata:
512
+ add_metadata()
513
+
514
+ # Loop through with sliding window
515
+ groups_search_index = 0
516
+ for start_sec in np.arange(0, song_length - min_window_length_sec, window_stride_sec):
517
+ end_sec = start_sec + window_length_sec
518
+
519
+ if audio_array is not None:
520
+ # Slice audio waveform
521
+ start_frame = int(start_sec * sampling_rate)
522
+ end_frame = int(end_sec * sampling_rate)
523
+ audio_slice = audio_array[start_frame:end_frame]
524
+ # Pad the audio array and calculate the number of audio tokens
525
+ audio_slice, num_audio_tokens = self._encode_audio(audio_slice, **audio_kwargs)
526
+ else:
527
+ audio_slice = None
528
+ num_audio_tokens = 0
529
+
530
+ # Find groups that fall within the current window
531
+ # Groups are sorted by time, so we can use a simple linear search from the last index
532
+ start_ms = start_sec * 1000
533
+ end_ms = end_sec * 1000
534
+ next_start_ms = (start_sec + window_stride_sec) * 1000
535
+ window_groups = []
536
+ for group in itertools.islice(beatmap_groups, groups_search_index, None):
537
+ if group.time < next_start_ms:
538
+ groups_search_index += 1
539
+
540
+ if group.time < start_ms:
541
+ continue
542
+ elif group.time < end_ms:
543
+ window_groups.append(group)
544
+ else:
545
+ break
546
+
547
+ batch_start_ms.append(start_ms)
548
+ batch_groups.append(window_groups)
549
+ batch_audio.append(audio_slice)
550
+ batch_num_audio_tokens.append(num_audio_tokens)
551
+
552
+ if multiply_metadata:
553
+ add_metadata(start_sec / song_length)
554
+
555
+ if populate_metadata or multiply_metadata:
556
+ metadata = new_metadata
557
+
558
+ if len(batch_groups) > 0:
559
+ beatmap_encoding = self.beatmap_tokenizer(
560
+ groups=batch_groups,
561
+ window_start_ms=batch_start_ms,
562
+ num_audio_tokens=batch_num_audio_tokens,
563
+ **beatmap_kwargs,
564
+ )
565
+
566
+ if audio is not None:
567
+ data = dict(beatmap_encoding)
568
+ data["input_features"] = self._retrieve_input_features(batch_audio, **audio_kwargs)
569
+ beatmap_encoding = BatchFeature(data, tensor_type=return_tensors)
570
+ else:
571
+ # No windows with hit objects were found, return empty encoding
572
+ logger.warning("Warning: No windows with hit objects were found in the provided beatmap(s). Returning empty encoding.")
573
+ beatmap_encoding = BatchEncoding(
574
+ {
575
+ "input_ids": torch.zeros((0, max_length), dtype=torch.long) if return_tensors == "pt" else [],
576
+ "attention_mask": torch.zeros((0, max_length), dtype=torch.long) if return_tensors == "pt" else [],
577
+ },
578
+ tensor_type=return_tensors,
579
+ )
580
+ if audio is not None:
581
+ data = dict(beatmap_encoding)
582
+ data["input_features"] = torch.zeros((0, self.audio_feature_extractor.feature_size, max_source_positions), dtype=torch.float) if return_tensors == "pt" else []
583
+ beatmap_encoding = BatchFeature(data, tensor_type=return_tensors)
584
+
585
+ if metadata is not None and not (isinstance(metadata, list) and any(m is None for m in metadata)):
586
+ if not isinstance(metadata, list):
587
+ metadata = [metadata]
588
+
589
+ if metadata_dropout_prob > 0.0:
590
+ for m in metadata:
591
+ # Randomly drop out metadata fields
592
+ for key, value in m.items():
593
+ if value is not None and np.random.rand() < metadata_dropout_prob:
594
+ # noinspection PyTypedDict
595
+ m[key] = None
596
+
597
+ if metadata_variations > 1:
598
+ extended_metadata = []
599
+ metadata_variation_classes = []
600
+ for m in metadata:
601
+ m_vars, m_classes = zip(*self.metadata_tokenizer.metadata_variations(m, metadata_variations - 1))
602
+ extended_metadata.append(m)
603
+ extended_metadata.extend(m_vars)
604
+ metadata_variation_classes.append([0] + list(m_classes)) # Class 0 is the original metadata
605
+
606
+ assert len(extended_metadata) == len(metadata) * metadata_variations
607
+ metadata = extended_metadata
608
+
609
+ if len(metadata) > 0:
610
+ metadata_encoding = self.metadata_tokenizer(
611
+ metadata,
612
+ **metadata_kwargs,
613
+ )
614
+ if metadata_variations > 1:
615
+ # Reshape to (batch_size, variations, seq_len)
616
+ for k, v in metadata_encoding.items():
617
+ if return_tensors == "pt":
618
+ v = v.view(len(metadata) // metadata_variations, metadata_variations, -1)
619
+ else:
620
+ v = [v[i:i + metadata_variations] for i in range(0, len(v), metadata_variations)]
621
+ metadata_encoding[k] = v
622
+ if metadata_variation_classes is not None:
623
+ metadata_encoding["metadata_variation_classes"] = torch.tensor(metadata_variation_classes, dtype=torch.long) if return_tensors == "pt" else metadata_variation_classes
624
+ else:
625
+ metadata_encoding = BatchEncoding(
626
+ {
627
+ "input_ids": torch.zeros((0, metadata_max_length), dtype=torch.long) if return_tensors == "pt" else [],
628
+ "attention_mask": torch.zeros((0, metadata_max_length), dtype=torch.long) if return_tensors == "pt" else [],
629
+ },
630
+ tensor_type=return_tensors,
631
+ )
632
+
633
+ if metadata_encoding is not None and beatmap_encoding is not None:
634
+ beatmap_encoding["metadata_ids"] = metadata_encoding["input_ids"]
635
+ beatmap_encoding["metadata_attention_mask"] = metadata_encoding["attention_mask"]
636
+ if "metadata_variation_classes" in metadata_encoding:
637
+ beatmap_encoding["metadata_variation_classes"] = metadata_encoding["metadata_variation_classes"]
638
+ return beatmap_encoding
639
+ elif beatmap_encoding is not None:
640
+ return beatmap_encoding
641
+ else:
642
+ return metadata_encoding
643
+
644
+ def batch_decode(self, *args, **kwargs):
645
+ """
646
+ This method forwards all its arguments to CM3PBeatmapTokenizer's [`~CM3PBeatmapTokenizer.batch_decode`]. Please
647
+ refer to the docstring of this method for more information.
648
+ """
649
+ return self.beatmap_tokenizer.batch_decode(*args, **kwargs)
650
+
651
+ def decode(self, *args, **kwargs):
652
+ """
653
+ This method forwards all its arguments to CM3PBeatmapTokenizer's [`~CM3PBeatmapTokenizer.decode`]. Please refer to
654
+ the docstring of this method for more information.
655
+ """
656
+ return self.beatmap_tokenizer.decode(*args, **kwargs)
657
+
658
+ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
659
+ """
660
+ Save processor and its sub-components, with support for AutoProcessor remote code.
661
+
662
+ This is a lightly adapted version of ProcessorMixin.save_pretrained:
663
+ - child attributes are saved into subfolders (audio_feature_extractor/, beatmap_parser/, ...);
664
+ - when self._auto_class is set (via register_for_auto_class), custom_object_save is used
665
+ so that auto_map and dynamic modules are written correctly.
666
+ """
667
+ os.makedirs(save_directory, exist_ok=True)
668
+
669
+ # Handle Hub integration (same as ProcessorMixin / your existing code)
670
+ if push_to_hub:
671
+ commit_message = kwargs.pop("commit_message", None)
672
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
673
+ repo_id = self._create_repo(repo_id, **kwargs)
674
+ files_timestamps = self._get_files_timestamps(save_directory)
675
+ else:
676
+ commit_message = None
677
+ repo_id = None
678
+ files_timestamps = None
679
+
680
+ # If we have a custom processor registered for an Auto class,
681
+ # save its code and dependencies as a dynamic module and
682
+ # populate the auto_map field in processor_config.json.
683
+ if self._auto_class is not None:
684
+ attrs = [getattr(self, attribute_name) for attribute_name in self.attributes]
685
+
686
+ # For tokenizers, we pass their init_kwargs; for other objects, we pass the object itself.
687
+ configs = []
688
+ for a in attrs:
689
+ if isinstance(a, PreTrainedTokenizerBase):
690
+ configs.append(a.init_kwargs)
691
+ else:
692
+ configs.append(a)
693
+
694
+ # Include the processor itself so its class is exported.
695
+ configs.append(self)
696
+
697
+ custom_object_save(self, save_directory, config=configs)
698
+
699
+ # Save each sub-component into its own subfolder
700
+ for attribute_name in self.attributes:
701
+ attribute = getattr(self, attribute_name)
702
+
703
+ # Include the processor class in the attribute config so this
704
+ # processor can then be reloaded with the AutoProcessor API.
705
+ if hasattr(attribute, "_set_processor_class"):
706
+ # noinspection PyProtectedMember
707
+ attribute._set_processor_class(self.__class__.__name__)
708
+
709
+ attribute.save_pretrained(os.path.join(save_directory, attribute_name))
710
+
711
+ # Clean up temporary auto_map injected into tokenizers, if any
712
+ if self._auto_class is not None:
713
+ for attribute_name in self.attributes:
714
+ attribute = getattr(self, attribute_name)
715
+ if isinstance(attribute, PreTrainedTokenizerBase) and "auto_map" in attribute.init_kwargs:
716
+ del attribute.init_kwargs["auto_map"]
717
+
718
+ # Write processor_config.json (or equivalent)
719
+ output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
720
+ processor_dict = self.to_dict()
721
+
722
+ # If processor_dict only contains processor_class, we skip writing the file,
723
+ # matching the upstream behavior; otherwise we save it.
724
+ if set(processor_dict.keys()) != {"processor_class"}:
725
+ self.to_json_file(output_processor_file)
726
+ # noinspection PyUnresolvedReferences
727
+ logger.warning_once(f"processor saved in {output_processor_file}")
728
+
729
+ # If requested, upload the modified files to the Hub
730
+ if push_to_hub:
731
+ self._upload_modified_files(
732
+ save_directory,
733
+ repo_id,
734
+ files_timestamps,
735
+ commit_message=commit_message,
736
+ token=kwargs.get("token"),
737
+ create_pr=kwargs.get("create_pr", False),
738
+ revision=kwargs.get("revision"),
739
+ commit_description=kwargs.get("commit_description"),
740
+ )
741
+
742
+ if set(processor_dict.keys()) == {"processor_class"}:
743
+ return []
744
+ return [output_processor_file]
745
+
746
+ @classmethod
747
+ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
748
+ subfolder = kwargs.pop("subfolder", None)
749
+ args = []
750
+ for attribute_name in cls.attributes:
751
+ class_name = getattr(cls, f"{attribute_name}_class")
752
+ attribute_class = cls.get_possibly_dynamic_module(class_name)
753
+ attribute_subfolder = os.path.join(subfolder, attribute_name) if subfolder else attribute_name
754
+
755
+ args.append(attribute_class.from_pretrained(
756
+ pretrained_model_name_or_path,
757
+ subfolder=attribute_subfolder,
758
+ **kwargs
759
+ ))
760
+
761
+ return args
762
+
763
+ def _upload_modified_files(
764
+ self,
765
+ working_dir: Union[str, os.PathLike],
766
+ repo_id: str,
767
+ files_timestamps: dict[str, float],
768
+ commit_message: Optional[str] = None,
769
+ token: Optional[Union[bool, str]] = None,
770
+ create_pr: bool = False,
771
+ revision: Optional[str] = None,
772
+ commit_description: Optional[str] = None,
773
+ ):
774
+ """
775
+ Uploads all modified files in `working_dir` to `repo_id`, based on `files_timestamps`.
776
+ """
777
+ working_dir = Path(working_dir)
778
+
779
+ if commit_message is None:
780
+ commit_message = "Upload CM3P processor"
781
+ modified_files = [
782
+ f
783
+ for f in working_dir.iterdir()
784
+ if str(f) not in files_timestamps or f.stat().st_mtime > files_timestamps[str(f)]
785
+ ]
786
+
787
+ # filter for actual files + folders at the root level
788
+ modified_files = [
789
+ f
790
+ for f in modified_files
791
+ if f.is_file() or f.is_dir()
792
+ ]
793
+
794
+ operations = []
795
+ # upload standalone files
796
+ for file in modified_files:
797
+ if file.is_dir():
798
+ # go over individual files of folder
799
+ for f in file.iterdir():
800
+ operations.append(
801
+ CommitOperationAdd(
802
+ path_or_fileobj=f, path_in_repo=f.relative_to(working_dir).as_posix()
803
+ )
804
+ )
805
+ else:
806
+ operations.append(
807
+ CommitOperationAdd(path_or_fileobj=file, path_in_repo=file.relative_to(working_dir).as_posix())
808
+ )
809
+
810
+ if revision is not None and not revision.startswith("refs/pr"):
811
+ try:
812
+ create_branch(repo_id=repo_id, branch=revision, token=token, exist_ok=True)
813
+ except HfHubHTTPError as e:
814
+ if e.response.status_code == 403 and create_pr:
815
+ # If we are creating a PR on a repo we don't have access to, we can't create the branch.
816
+ # so let's assume the branch already exists. If it's not the case, an error will be raised when
817
+ # calling `create_commit` below.
818
+ pass
819
+ else:
820
+ raise
821
+
822
+ logger.info(f"Uploading the following files to {repo_id}: {','.join([f.relative_to(working_dir).as_posix() for f in modified_files])}")
823
+ return create_commit(
824
+ repo_id=repo_id,
825
+ operations=operations,
826
+ commit_message=commit_message,
827
+ commit_description=commit_description,
828
+ token=token,
829
+ create_pr=create_pr,
830
+ revision=revision,
831
+ )
832
+
833
+ AutoProcessor.register(CM3PConfig, CM3PProcessor)
834
+
835
+ __all__ = ["CM3PProcessor", "get_metadata"]
processor_config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_cm3p.CM3PProcessor"
4
+ },
5
+ "default_kwargs": {
6
+ "audio_kwargs": {
7
+ "audio_length_per_tok": 8,
8
+ "hop_length": 160,
9
+ "max_source_positions": 1600,
10
+ "pad_to_multiple_of": 256000,
11
+ "padding": false,
12
+ "sampling_rate": 16000,
13
+ "truncation": false,
14
+ "window_size": 400
15
+ },
16
+ "beatmap_kwargs": {
17
+ "max_length": 2000,
18
+ "padding": "longest",
19
+ "truncation": "longest_first",
20
+ "window_length_sec": 16.0,
21
+ "window_stride_sec": 16.0
22
+ },
23
+ "common_kwargs": {
24
+ "return_tensors": "pt"
25
+ },
26
+ "metadata_kwargs": {
27
+ "max_length": 128,
28
+ "padding": "longest",
29
+ "truncation": "longest_first"
30
+ }
31
+ },
32
+ "processor_class": "CM3PProcessor"
33
+ }
tokenization_cm3p.py ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ from typing import Optional, Union, TypedDict
4
+
5
+ import numpy as np
6
+ from transformers import PreTrainedTokenizer, BatchEncoding, AutoTokenizer
7
+ from transformers.tokenization_utils_base import TruncationStrategy
8
+ from transformers.utils import PaddingStrategy
9
+
10
+ from .configuration_cm3p import CM3PBeatmapConfig, CM3PMetadataConfig
11
+ from .parsing_cm3p import Group, EventType, EVENT_TYPES_WITH_NEW_COMBO
12
+
13
+
14
class CM3PBeatmapTokenizer(PreTrainedTokenizer):
    """Tokenizer that converts parsed beatmap event ``Group`` sequences into token ids.

    The vocabulary is generated procedurally from the quantization settings
    (time shifts, distances, positions, scroll speeds, hitsounds, ...) unless a
    ``vocab_file`` is given, in which case the saved vocabulary is loaded as-is.
    """

    model_input_names: list[str] = ["input_ids", "attention_mask"]
    vocab_files_names: dict[str, str] = {"vocab_file": "vocab.json"}

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        min_time: int = 0,
        max_time: int = 30000,
        time_step: int = 10,
        max_distance: int = 640,
        distance_step: int = 4,
        position_range: tuple[int, int, int, int] = (-256, 768, -256, 640),
        position_step: int = 4,
        position_split_axes: bool = True,
        add_cls_token: bool = False,
        separate_new_combo_token: bool = True,
        **kwargs,
    ):
        """
        Args:
            vocab_file: Optional path to a saved ``vocab.json``; when omitted the
                vocabulary is rebuilt from the quantization settings below.
            min_time / max_time / time_step: Range and resolution (ms) of the
                ``[TIME_SHIFT_*]`` tokens.
            max_distance / distance_step: Range and quantization of ``[DISTANCE_*]`` tokens.
            position_range: (min_x, max_x, min_y, max_y) bounds of position tokens.
            position_step: Quantization step of position tokens.
            position_split_axes: Emit separate ``[POS_X_*]``/``[POS_Y_*]`` tokens
                instead of one combined ``[POS_x_y]`` token.
            add_cls_token: Prepend ``[CLS]`` to every encoded sequence.
            separate_new_combo_token: Use a standalone ``[NEW_COMBO]`` token rather
                than merged ``[<EVENT>_NEW_COMBO]`` event tokens.
        """
        # Settings must be assigned before _build_vocab_from_config() runs.
        self.min_time = min_time
        self.max_time = max_time
        self.time_step = time_step
        self.max_distance = max_distance
        self.distance_step = distance_step
        self.position_range = position_range
        self.position_step = position_step
        self.position_split_axes = position_split_axes
        self.add_cls_token = add_cls_token
        self.separate_new_combo_token = separate_new_combo_token

        self.audio_bos_token = "[AUDIO_BOS]"
        self.audio_eos_token = "[AUDIO_EOS]"
        self.audio_token = "[AUDIO]"

        if vocab_file is None:
            self.vocab = self._build_vocab_from_config()
        else:
            with open(vocab_file, 'r', encoding='utf-8') as f:
                self.vocab = json.load(f)

        self.ids_to_tokens = {i: t for t, i in self.vocab.items()}
        # All settings are forwarded to the base class so save_pretrained()
        # persists them in tokenizer_config.json.
        super().__init__(
            bos_token=kwargs.pop("bos_token", "[BOS]"),
            eos_token=kwargs.pop("eos_token", "[EOS]"),
            unk_token=kwargs.pop("unk_token", "[UNK]"),
            sep_token=kwargs.pop("sep_token", "[SEP]"),
            pad_token=kwargs.pop("pad_token", "[PAD]"),
            cls_token=kwargs.pop("cls_token", "[CLS]"),
            mask_token=kwargs.pop("mask_token", "[MASK]"),
            additional_special_tokens=kwargs.pop("additional_special_tokens", [
                self.audio_bos_token,
                self.audio_eos_token,
                self.audio_token,
            ]),
            min_time=min_time,
            max_time=max_time,
            time_step=time_step,
            max_distance=max_distance,
            distance_step=distance_step,
            position_range=position_range,
            position_step=position_step,
            position_split_axes=position_split_axes,
            add_cls_token=add_cls_token,
            separate_new_combo_token=separate_new_combo_token,
            **kwargs
        )

    def _build_vocab_from_config(self):
        """Generate the full token list from the quantization settings."""
        vocab = []

        for event_type in EventType:
            vocab.append(f"[{event_type.value.upper()}]")

        if not self.separate_new_combo_token:
            # Merged event+new-combo variants, one per event type supporting them.
            for event_type in EVENT_TYPES_WITH_NEW_COMBO:
                vocab.append(f"[{event_type.value.upper()}_NEW_COMBO]")

        # The +1e-5 guards against float rounding excluding the inclusive upper bound.
        for time in np.arange(self.min_time, self.max_time + 1e-5, self.time_step):
            vocab.append(f"[TIME_SHIFT_{int(time)}]")

        for snapping in range(0, 17):
            vocab.append(f"[SNAPPING_{snapping}]")

        for distance in range(0, self.max_distance + 1):
            vocab.append(f"[DISTANCE_{distance}]")

        if self.position_split_axes:
            for x in np.arange(self.position_range[0], self.position_range[1] + 1e-5, self.position_step):
                vocab.append(f"[POS_X_{int(x)}]")
            for y in np.arange(self.position_range[2], self.position_range[3] + 1e-5, self.position_step):
                vocab.append(f"[POS_Y_{int(y)}]")
        else:
            # Cartesian product of quantized x/y positions.
            for x in np.arange(self.position_range[0], self.position_range[1] + 1e-5, self.position_step):
                for y in np.arange(self.position_range[2], self.position_range[3] + 1e-5, self.position_step):
                    vocab.append(f"[POS_{int(x)}_{int(y)}]")

        for mania_column in range(1, 19):
            vocab.append(f"[MANIA_COLUMN_{mania_column}]")

        for scroll_speed in np.arange(0.0, 10.0 + 1e-5, 0.01):
            vocab.append(f"[SCROLL_SPEED_{scroll_speed:.2f}]")

        if self.separate_new_combo_token:
            vocab.append("[NEW_COMBO]")

        # Hitsound tokens always carry an even bitmask (bit 0 is unused; see
        # _tokenize_hitsound which shifts it out).
        for hitsound in range(8):
            for sampleset in range(1, 4):
                for additions in range(1, 4):
                    vocab.append(f"[HITSOUND_{(hitsound << 1)}_{sampleset}_{additions}]")

        for volume in range(101):
            vocab.append(f"[VOLUME_{volume}]")

        return {token: idx for idx, token in enumerate(vocab)}

    def _tokenize_time_shift(self, time: int):
        """Clamp and quantize a time shift (ms) to its vocabulary token."""
        time = np.clip(time, self.min_time, self.max_time)
        time = round(time / self.time_step) * self.time_step
        return f"[TIME_SHIFT_{int(time)}]"

    def _tokenize_distance(self, distance: int):
        """Clamp and quantize a distance to its vocabulary token."""
        distance = np.clip(distance, 0, self.max_distance)
        distance = round(distance / self.distance_step) * self.distance_step
        return f"[DISTANCE_{distance}]"

    def _tokenize_position(self, pos_x: int, pos_y: int):
        """Yield one or two position tokens depending on ``position_split_axes``."""
        pos_x = np.clip(pos_x, self.position_range[0], self.position_range[1])
        pos_y = np.clip(pos_y, self.position_range[2], self.position_range[3])
        pos_x = round(pos_x / self.position_step) * self.position_step
        pos_y = round(pos_y / self.position_step) * self.position_step

        if self.position_split_axes:
            yield f"[POS_X_{int(pos_x)}]"
            yield f"[POS_Y_{int(pos_y)}]"
        else:
            yield f"[POS_{int(pos_x)}_{int(pos_y)}]"

    def _tokenize_mania_column(self, mania_column: int):
        """Clamp a mania column index to [1, 18] and return its token."""
        mania_column = np.clip(mania_column, 1, 18)
        return f"[MANIA_COLUMN_{mania_column}]"

    def _tokenize_scroll_speed(self, scroll_speed: float):
        """Clamp to [0, 10] and quantize to 0.01 steps."""
        scroll_speed = np.clip(scroll_speed, 0.0, 10.0)
        scroll_speed = round(scroll_speed / 0.01) * 0.01
        return f"[SCROLL_SPEED_{scroll_speed:.2f}]"

    def _tokenize_hitsound(self, hitsound: int, sampleset: int, addition: int):
        """Quantize one hitsound triple; bit 0 of the bitmask is discarded."""
        hitsound = np.clip(hitsound >> 1, 0, 7) << 1
        sampleset = np.clip(sampleset, 1, 3)
        addition = np.clip(addition, 1, 3)
        return f"[HITSOUND_{hitsound}_{sampleset}_{addition}]"

    def _tokenize_groups(
        self,
        groups: list[Group],
        window_start_ms: Optional[int] = None,
        **_
    ):
        """Convert one window of ``Group`` events into a flat list of token strings."""
        window_start_ms = window_start_ms or 0
        tokens = []
        if self.add_cls_token:
            tokens.append(self.cls_token)
        tokens.append(self.bos_token)

        for group in groups:
            if group.new_combo and not self.separate_new_combo_token and group.event_type in EVENT_TYPES_WITH_NEW_COMBO:
                tokens.append(f"[{group.event_type.value.upper()}_NEW_COMBO]")
            else:
                tokens.append(f"[{group.event_type.value.upper()}]")
            if group.has_time:
                # Event times are encoded relative to the window start.
                tokens.append(self._tokenize_time_shift(group.time - window_start_ms))
            if group.snapping is not None:
                tokens.append(f"[SNAPPING_{group.snapping}]")
            if group.distance is not None:
                tokens.append(self._tokenize_distance(group.distance))
            if group.x is not None and group.y is not None:
                tokens.extend(self._tokenize_position(group.x, group.y))
            if group.mania_column is not None:
                tokens.append(self._tokenize_mania_column(group.mania_column))
            if group.new_combo and self.separate_new_combo_token:
                tokens.append("[NEW_COMBO]")
            if group.scroll_speed is not None:
                tokens.append(self._tokenize_scroll_speed(group.scroll_speed))
            # Hitsound lists are parallel arrays; each entry gets a hitsound and a
            # volume token.
            for h, s, a, v in zip(
                group.hitsounds,
                group.samplesets,
                group.additions,
                group.volumes,
            ):
                tokens.append(self._tokenize_hitsound(h, s, a))
                tokens.append(f"[VOLUME_{v}]")

        tokens.append(self.eos_token)
        return tokens

    def _encode_single(
        self,
        groups: Optional[list[Group]] = None,
        window_start_ms: Optional[int] = None,
        num_audio_tokens: Optional[int] = None,
    ):
        """Encode one window, optionally prefixing audio placeholder tokens."""
        token_strings = self._tokenize_groups(groups, window_start_ms=window_start_ms)
        token_ids = self.convert_tokens_to_ids(token_strings)

        if num_audio_tokens is not None and num_audio_tokens > 0:
            # Audio placeholders go in front: [AUDIO_BOS] [AUDIO]*n [AUDIO_EOS].
            audio_tokens = [self.audio_bos_token] + [self.audio_token] * num_audio_tokens + [self.audio_eos_token]
            token_ids = self.convert_tokens_to_ids(audio_tokens) + token_ids

        return token_ids

    def __call__(
        self,
        groups: Optional[Union[list[Group], list[list[Group]]]] = None,
        window_start_ms: Optional[Union[int, list[int]]] = None,
        num_audio_tokens: Optional[Union[int, list[int]]] = None,
        padding: PaddingStrategy = PaddingStrategy.LONGEST,
        truncation: TruncationStrategy = TruncationStrategy.LONGEST_FIRST,
        **kwargs
    ) -> BatchEncoding:
        """Encode a single window (``list[Group]``) or a batch (``list[list[Group]]``).

        Raises:
            ValueError: If ``groups`` is None/empty or has an unsupported type.
        """
        # Fix: `groups` defaults to None, so guard before len() — previously
        # calling the tokenizer with no groups raised an opaque TypeError.
        if groups is None or len(groups) == 0:
            raise ValueError("Input groups list is empty.")

        if isinstance(groups, list) and all(isinstance(g, Group) for g in groups):
            # Single sequence.
            token_ids = self._encode_single(
                groups=groups,
                window_start_ms=window_start_ms,
                num_audio_tokens=num_audio_tokens,
            )
            encoding = self.prepare_for_model(
                token_ids,
                padding=padding,
                truncation=truncation,
                **kwargs,
            )
        elif isinstance(groups, list):
            # Batch of sequences; broadcast missing per-sequence arguments.
            if num_audio_tokens is None:
                num_audio_tokens = [None] * len(groups)

            if window_start_ms is None:
                window_start_ms = [None] * len(groups)

            if len(groups) != len(num_audio_tokens):
                raise ValueError("Number of num_audio_tokens inputs must match the number of sequences.")

            if len(window_start_ms) != len(groups):
                raise ValueError("Number of window start times must match the number of sequences.")

            all_token_ids = []
            for g, w, a in zip(groups, window_start_ms, num_audio_tokens):
                token_ids = self._encode_single(
                    groups=g,
                    window_start_ms=w,
                    num_audio_tokens=a,
                )
                # _batch_prepare_for_model expects (ids, pair_ids) tuples.
                all_token_ids.append((token_ids, None))

            encoding = self._batch_prepare_for_model(
                all_token_ids,
                padding_strategy=PaddingStrategy(padding),
                truncation_strategy=TruncationStrategy(truncation),
                **kwargs,
            )
        else:
            raise ValueError("Input must be a list of Group objects or a single Group object.")

        return encoding

    @property
    def vocab_size(self):
        # Base vocabulary plus tokens added at runtime (the special tokens above
        # are not part of the generated vocab and live in _added_tokens_encoder).
        return len(self.vocab) + len(self._added_tokens_encoder)

    def get_vocab(self):
        return self.vocab | self._added_tokens_encoder

    def _convert_token_to_id(self, token):
        # NOTE(review): "[UNK]" is registered as an added special token, not a
        # base-vocab entry, so the fallback lookup yields None for tokens that
        # are in neither map — confirm unknown tokens cannot reach this path.
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        return self.ids_to_tokens.get(index, self.unk_token)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        """Write the base vocabulary to ``<save_directory>/<prefix>vocab.json``."""
        if not save_directory:
            raise ValueError("The save_directory must be specified.")

        vocab_file = f"{save_directory}/{filename_prefix or ''}vocab.json"
        with open(vocab_file, 'w', encoding='utf-8') as f:
            json.dump(self.vocab, f, ensure_ascii=False)

        return (vocab_file,)
303
+
304
+
305
class CM3PMetadata(TypedDict, total=False):
    """
    Metadata fields for a beatmap. All fields are optional (``total=False``),
    so consumers must tolerate absent keys.

    difficulty: Star rating, unitless (osu! difficulty)
    year: Year of beatmap creation (YYYY)
    mode: Game mode ID or name (e.g., "osu", "mania")
    status: Beatmap status ID or name (e.g., "ranked", "approved", "loved", "pending", "graveyard")
    mapper: Beatmap creator's ID or username
    cs: Circle size (osu!std), unitless
    hitsounded: Whether the beatmap is hitsounded (True/False)
    song_length: Song length in seconds
    song_position: Relative position in song [0.0-1.0], unitless
    global_sv: Global slider velocity, multiplier
        (NOTE(review): the inline comment below says osu!standard/catch slider
        velocity while this docstring originally said osu!mania scroll
        velocity — confirm which game modes this field applies to.)
    mania_keycount: Number of keys in osu!mania [1-18]
    hold_note_ratio: Ratio of hold notes [0.0-1.0], unitless
    scroll_speed_ratio: Ratio of scroll speed changes [0.0-1.0], unitless
    tags: List of beatmap tag IDs or names
    """
    difficulty: float  # Star rating, unitless (osu! difficulty)
    year: int  # Year of beatmap creation (YYYY)
    mode: Union[int, str]  # Game mode ID or name (e.g., "osu", "mania")
    status: Union[int, str]  # Beatmap status (e.g., "ranked", "approved", "loved", "pending", "graveyard")
    mapper: Union[int, str]  # Beatmap creator's ID or username
    cs: float  # Circle size (osu!std), unitless
    hitsounded: bool  # Whether the beatmap is hitsounded (True/False)
    song_length: float  # Song length in seconds
    song_position: float  # Relative position in song [0.0-1.0], unitless
    global_sv: float  # Global slider velocity (osu!standard/catch), multiplier
    mania_keycount: int  # Number of keys in osu!mania [1-18]
    hold_note_ratio: float  # Ratio of hold notes [0.0-1.0], unitless
    scroll_speed_ratio: float  # Ratio of scroll speed changes [0.0-1.0], unitless
    tags: list[Union[int, str]]  # List of beatmap tag IDs or names
337
+
338
+
339
def merge_metadata_dicts(m1, m2):
    """Merge two metadata dicts, preferring non-None values from ``m1``.

    Returns the other dict unchanged when either argument is None. Only keys
    declared on ``CM3PMetadata`` are considered; anything else is dropped.
    """
    if m1 is None or m2 is None:
        return m1 if m2 is None else m2
    combined = {
        field: (m1.get(field) if m1.get(field) is not None else m2.get(field))
        for field in CM3PMetadata.__annotations__
    }
    return CM3PMetadata(**combined)
350
+
351
+
352
class CM3PMetadataTokenizer(PreTrainedTokenizer):
    """Tokenizer that encodes ``CM3PMetadata`` dicts as a fixed-field-order token
    sequence (difficulty, year, mode, status, mapper, cs, hitsounded, song
    length/position, global SV, keycount, ratios, then a variable tag list).

    Each field has a dedicated ``[<FIELD>_UNK]`` token used when the value is
    missing or not representable.
    """

    model_input_names: list[str] = ["input_ids", "attention_mask"]
    vocab_files_names: dict[str, str] = {"vocab_file": "vocab.json"}

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        modes: Optional[dict[int, str]] = None,
        statuses: Optional[dict[int, str]] = None,
        mappers: Optional[dict[int, str]] = None,
        tags: Optional[dict[int, dict]] = None,
        min_difficculty: float = 0.0,
        max_difficulty: float = 14.0,
        difficulty_step: float = 0.1,
        min_year: int = 2000,
        max_year: int = 2023,
        max_song_length: int = 600,
        song_length_step: int = 10,
        song_position_step: float = 0.01,
        global_sv_step: float = 0.01,
        hold_note_ratio_step: float = 0.1,
        scroll_speed_ratio_step: float = 0.1,
        add_cls_token: bool = False,
        **kwargs,
    ):
        """
        Args:
            vocab_file: Optional saved ``vocab.json``; otherwise the vocabulary
                is rebuilt from the settings and lookup tables below.
            modes / statuses / mappers: id -> name lookup tables.
            tags: id -> ``{"name": ...}`` lookup table.
            min_difficculty: (sic) lower bound of the difficulty tokens. The
                misspelling is kept deliberately: it is part of the public
                signature and of serialized tokenizer configs.
            Remaining *_step / min_* / max_* arguments define the quantization
            ranges used to generate the per-field tokens.
            add_cls_token: Prepend ``[CLS]`` to every encoded sequence.
        """
        self.min_difficulty = min_difficculty
        self.max_difficulty = max_difficulty
        self.difficulty_step = difficulty_step
        self.min_year = min_year
        self.max_year = max_year
        self.max_song_length = max_song_length
        self.song_length_step = song_length_step
        self.song_position_step = song_position_step
        self.global_sv_step = global_sv_step
        self.hold_note_ratio_step = hold_note_ratio_step
        self.scroll_speed_ratio_step = scroll_speed_ratio_step
        self.add_cls_token = add_cls_token

        # Per-field "unknown" fallback tokens.
        self.difficulty_unk_token = "[DIFFICULTY_UNK]"
        self.year_unk_token = "[YEAR_UNK]"
        self.mode_unk_token = "[MODE_UNK]"
        self.status_unk_token = "[STATUS_UNK]"
        self.mapper_unk_token = "[MAPPER_UNK]"
        self.cs_unk_token = "[CS_UNK]"
        self.hitsounded_unk_token = "[HITSOUNDED_UNK]"
        self.song_length_unk_token = "[SONG_LENGTH_UNK]"
        self.song_position_unk_token = "[SONG_POSITION_UNK]"
        self.global_sv_unk_token = "[GLOBAL_SV_UNK]"
        self.mania_keycount_unk_token = "[MANIA_KEYCOUNT_UNK]"
        self.hold_note_ratio_unk_token = "[HOLD_NOTE_RATIO_UNK]"
        self.scroll_speed_ratio_unk_token = "[SCROLL_SPEED_RATIO_UNK]"
        self.tag_unk_token = "[TAG_UNK]"

        self.modes = modes or {}
        self.statuses = statuses or {}
        self.mappers = mappers or {}
        self.tags = tags or {}
        self.mode_names_to_ids = {v: k for k, v in self.modes.items()}
        self.mode_ids_to_names = self.modes
        self.status_names_to_ids = {v: k for k, v in self.statuses.items()}
        self.status_ids_to_names = self.statuses
        self.mapper_names_to_ids = {v: k for k, v in self.mappers.items()}
        self.mapper_ids_to_names = self.mappers
        self.tag_names_to_ids = {v['name']: k for k, v in self.tags.items()}
        self.tag_ids_to_names = {k: v['name'] for k, v in self.tags.items()}

        if vocab_file is None:
            self.vocab = self._build_vocab_from_config()
        else:
            with open(vocab_file, 'r', encoding='utf-8') as f:
                self.vocab = json.load(f)

        self.ids_to_tokens = {i: t for t, i in self.vocab.items()}

        # Settings are forwarded so save_pretrained() persists them.
        super().__init__(
            bos_token=kwargs.pop("bos_token", "[BOS]"),
            eos_token=kwargs.pop("eos_token", "[EOS]"),
            pad_token=kwargs.pop("pad_token", "[PAD]"),
            cls_token=kwargs.pop("cls_token", "[CLS]"),
            additional_special_tokens=kwargs.pop("additional_special_tokens", [
                self.difficulty_unk_token,
                self.year_unk_token,
                self.mode_unk_token,
                self.status_unk_token,
                self.mapper_unk_token,
                self.cs_unk_token,
                self.hitsounded_unk_token,
                self.song_length_unk_token,
                self.song_position_unk_token,
                self.global_sv_unk_token,
                self.mania_keycount_unk_token,
                self.hold_note_ratio_unk_token,
                self.scroll_speed_ratio_unk_token,
                self.tag_unk_token,
            ]),
            modes=modes,
            statuses=statuses,
            mappers=mappers,
            tags=tags,
            min_difficculty=min_difficculty,
            max_difficulty=max_difficulty,
            difficulty_step=difficulty_step,
            min_year=min_year,
            max_year=max_year,
            max_song_length=max_song_length,
            song_length_step=song_length_step,
            song_position_step=song_position_step,
            global_sv_step=global_sv_step,
            hold_note_ratio_step=hold_note_ratio_step,
            scroll_speed_ratio_step=scroll_speed_ratio_step,
            add_cls_token=add_cls_token,
            **kwargs
        )

    def _build_vocab_from_config(self):
        """Generate all metadata tokens from the ranges and lookup tables.

        The +1e-5 in the np.arange bounds guards against float rounding
        excluding the inclusive upper bound.
        """
        vocab = []

        for difficulty in np.arange(self.min_difficulty, self.max_difficulty + 1e-5, self.difficulty_step):
            vocab.append(f"[DIFFICULTY_{difficulty:.1f}]")

        for year in range(self.min_year, self.max_year + 1):
            vocab.append(f"[YEAR_{year}]")

        for mode in self.mode_ids_to_names.values():
            vocab.append(f"[MODE_{str(mode)}]")

        for status in self.status_ids_to_names.values():
            vocab.append(f"[STATUS_{str(status)}]")

        # Mapper tokens use the numeric mapper id, not the display name.
        for mapper in self.mapper_ids_to_names.keys():
            vocab.append(f"[MAPPER_{str(mapper)}]")

        for cs in np.arange(0.0, 10.0 + 1e-5, 0.1):
            vocab.append(f"[CS_{cs:.1f}]")

        for hitsounded in [True, False]:
            vocab.append(f"[HITSOUNDED_{str(hitsounded).upper()}]")

        for song_length in np.arange(0, self.max_song_length + 1e-5, self.song_length_step):
            vocab.append(f"[SONG_LENGTH_{int(song_length)}]")

        for song_position in np.arange(0.0, 1.0 + 1e-5, self.song_position_step):
            vocab.append(f"[SONG_POSITION_{song_position:.2f}]")

        for global_sv in np.arange(0.4, 3.6 + 1e-5, self.global_sv_step):
            vocab.append(f"[GLOBAL_SV_{global_sv:.2f}]")

        for mania_keycount in range(1, 19):
            vocab.append(f"[MANIA_KEYCOUNT_{mania_keycount}]")

        for hold_note_ratio in np.arange(0.0, 1.0 + 1e-5, self.hold_note_ratio_step):
            vocab.append(f"[HOLD_NOTE_RATIO_{hold_note_ratio:.1f}]")

        for scroll_speed_ratio in np.arange(0.0, 1.0 + 1e-5, self.scroll_speed_ratio_step):
            vocab.append(f"[SCROLL_SPEED_RATIO_{scroll_speed_ratio:.1f}]")

        for tag in self.tag_ids_to_names.values():
            vocab.append(f"[TAG_{tag}]")

        return {token: idx for idx, token in enumerate(vocab)}

    def _tokenize_difficulty(self, metadata: CM3PMetadata):
        """Quantized difficulty token, or the UNK token when missing."""
        difficulty = metadata.get('difficulty', None)
        if difficulty is None:
            return self.difficulty_unk_token
        difficulty = np.clip(difficulty, self.min_difficulty, self.max_difficulty)
        difficulty = round(difficulty / self.difficulty_step) * self.difficulty_step
        return f"[DIFFICULTY_{difficulty:.1f}]"

    def _tokenize_year(self, metadata: CM3PMetadata):
        year = metadata.get('year', None)
        if year is None:
            return self.year_unk_token
        year = np.clip(year, self.min_year, self.max_year)
        return f"[YEAR_{year}]"

    def _tokenize_mode(self, metadata: CM3PMetadata):
        """Mode token from a name or id; UNK when not in the lookup table."""
        mode_str = metadata.get('mode', None)
        if isinstance(mode_str, int):
            mode_str = self.mode_ids_to_names.get(mode_str, None)
        if mode_str is None or mode_str not in self.mode_names_to_ids:
            return self.mode_unk_token
        return f"[MODE_{str(mode_str)}]"

    def _tokenize_status(self, metadata: CM3PMetadata):
        """Status token from a name or id; UNK when not in the lookup table."""
        status_str = metadata.get('status', None)
        if isinstance(status_str, int):
            status_str = self.status_ids_to_names.get(status_str, None)
        if status_str is None or status_str not in self.status_names_to_ids:
            return self.status_unk_token
        return f"[STATUS_{str(status_str)}]"

    def _tokenize_mapper(self, metadata: CM3PMetadata):
        """Mapper token from a username or id; tokens are keyed by id."""
        mapper_id = metadata.get('mapper', None)
        if isinstance(mapper_id, str):
            mapper_id = self.mapper_names_to_ids.get(mapper_id, None)
        if mapper_id is None or mapper_id not in self.mapper_ids_to_names:
            return self.mapper_unk_token
        return f"[MAPPER_{str(mapper_id)}]"

    def _tokenize_cs(self, metadata: CM3PMetadata):
        cs = metadata.get('cs', None)
        if cs is None:
            return self.cs_unk_token
        cs = np.clip(cs, 0.0, 10.0)
        cs = round(cs / 0.1) * 0.1
        return f"[CS_{cs:.1f}]"

    def _tokenize_hitsounded(self, metadata: CM3PMetadata):
        hitsounded = metadata.get('hitsounded', None)
        if hitsounded is None:
            return self.hitsounded_unk_token
        return f"[HITSOUNDED_{str(hitsounded).upper()}]"

    def _tokenize_song_length(self, metadata: CM3PMetadata):
        song_length = metadata.get('song_length', None)
        if song_length is None:
            return self.song_length_unk_token
        song_length = np.clip(song_length, 0, self.max_song_length)
        song_length = round(song_length / self.song_length_step) * self.song_length_step
        return f"[SONG_LENGTH_{int(song_length)}]"

    def _tokenize_song_position(self, metadata: CM3PMetadata):
        song_position = metadata.get('song_position', None)
        if song_position is None:
            return self.song_position_unk_token
        song_position = np.clip(song_position, 0.0, 1.0)
        song_position = round(song_position / self.song_position_step) * self.song_position_step
        return f"[SONG_POSITION_{song_position:.2f}]"

    def _tokenize_global_sv(self, metadata: CM3PMetadata):
        global_sv = metadata.get('global_sv', None)
        if global_sv is None:
            return self.global_sv_unk_token
        global_sv = np.clip(global_sv, 0.4, 3.6)
        global_sv = round(global_sv / self.global_sv_step) * self.global_sv_step
        return f"[GLOBAL_SV_{global_sv:.2f}]"

    def _tokenize_mania_keycount(self, metadata: CM3PMetadata):
        mania_keycount = metadata.get('mania_keycount', None)
        if mania_keycount is None:
            return self.mania_keycount_unk_token
        mania_keycount = int(mania_keycount)
        mania_keycount = np.clip(mania_keycount, 1, 18)
        return f"[MANIA_KEYCOUNT_{mania_keycount}]"

    def _tokenize_hold_note_ratio(self, metadata: CM3PMetadata):
        hold_note_ratio = metadata.get('hold_note_ratio', None)
        if hold_note_ratio is None:
            return self.hold_note_ratio_unk_token
        hold_note_ratio = np.clip(hold_note_ratio, 0.0, 1.0)
        hold_note_ratio = round(hold_note_ratio / self.hold_note_ratio_step) * self.hold_note_ratio_step
        return f"[HOLD_NOTE_RATIO_{hold_note_ratio:.1f}]"

    def _tokenize_scroll_speed_ratio(self, metadata: CM3PMetadata):
        scroll_speed_ratio = metadata.get('scroll_speed_ratio', None)
        if scroll_speed_ratio is None:
            return self.scroll_speed_ratio_unk_token
        scroll_speed_ratio = np.clip(scroll_speed_ratio, 0.0, 1.0)
        scroll_speed_ratio = round(scroll_speed_ratio / self.scroll_speed_ratio_step) * self.scroll_speed_ratio_step
        return f"[SCROLL_SPEED_RATIO_{scroll_speed_ratio:.1f}]"

    def _validate_tags(self, tags):
        """Normalize tag ids to names, dropping anything not in the lookup table."""
        if tags is None:
            return None
        new_tags = []
        for tag in tags:
            if isinstance(tag, str) and tag in self.tag_names_to_ids:
                new_tags.append(tag)
            elif tag in self.tag_ids_to_names:
                new_tags.append(self.tag_ids_to_names[tag])
        return new_tags

    def _tokenize_tags(self, metadata: CM3PMetadata):
        """Tag tokens for every recognized tag, or a single UNK token."""
        tags = metadata.get('tags', None)
        valid_tags = self._validate_tags(tags)
        if not valid_tags:
            return [self.tag_unk_token]
        return [f"[TAG_{tag}]" for tag in valid_tags]

    def _tokenize_metadata(self, metadata: CM3PMetadata):
        """Full token sequence: [CLS]? [BOS] <fixed fields> <tags...> [EOS]."""
        tokens = []
        if self.add_cls_token:
            tokens.append(self.cls_token)
        tokens.extend([
            self.bos_token,
            self._tokenize_difficulty(metadata),
            self._tokenize_year(metadata),
            self._tokenize_mode(metadata),
            self._tokenize_status(metadata),
            self._tokenize_mapper(metadata),
            self._tokenize_cs(metadata),
            self._tokenize_hitsounded(metadata),
            self._tokenize_song_length(metadata),
            self._tokenize_song_position(metadata),
            self._tokenize_global_sv(metadata),
            self._tokenize_mania_keycount(metadata),
            self._tokenize_hold_note_ratio(metadata),
            self._tokenize_scroll_speed_ratio(metadata),
        ])
        tokens.extend(self._tokenize_tags(metadata))
        tokens.append(self.eos_token)
        return tokens

    def __call__(
        self,
        metadata: Optional[Union[CM3PMetadata, list[CM3PMetadata]]] = None,
        padding: PaddingStrategy = PaddingStrategy.LONGEST,
        truncation: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        return_tensors: Optional[str] = "pt",
        **kwargs
    ) -> BatchEncoding:
        """Encode one metadata dict or a batch of them.

        Raises:
            ValueError: If ``metadata`` is None or neither a dict nor a list.
        """
        if isinstance(metadata, dict):
            token_strings = self._tokenize_metadata(metadata)
            token_ids = self.convert_tokens_to_ids(token_strings)
            return self.prepare_for_model(
                token_ids,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                **kwargs,
            )
        elif isinstance(metadata, list):
            all_token_ids = []
            for m in metadata:
                token_strings = self._tokenize_metadata(m)
                token_ids = self.convert_tokens_to_ids(token_strings)
                # _batch_prepare_for_model expects (ids, pair_ids) tuples.
                all_token_ids.append((token_ids, None))

            # Fix: forward **kwargs for consistency with the single-input path
            # and with CM3PBeatmapTokenizer (they were silently dropped before).
            return self._batch_prepare_for_model(
                all_token_ids,
                padding_strategy=PaddingStrategy(padding),
                truncation_strategy=TruncationStrategy(truncation),
                max_length=max_length,
                return_tensors=return_tensors,
                **kwargs,
            )
        # Fix: previously this fell through and implicitly returned None.
        raise ValueError("Input must be a CM3PMetadata dict or a list of CM3PMetadata dicts.")

    def metadata_variations(self, metadata: CM3PMetadata, num_variations: int = 1000) -> tuple[CM3PMetadata, int]:
        """Yield up to ``num_variations`` (metadata, category) pairs, each differing
        from ``metadata`` in exactly one field, interleaving the per-field
        generators round-robin. Categories: 1=year, 2=status, 3=tags, 4=mapper,
        -1=empty padding metadata emitted once the real generators run dry.
        """
        def year_variations():
            min_year = max(2007, self.min_year)
            # Fix: CM3PMetadata is total=False, so use .get() — direct indexing
            # raised KeyError for absent fields.
            year = metadata.get("year")
            if year is None or (min_year > year or year > self.max_year):
                return
            for candidate in range(min_year, self.max_year + 1):
                if candidate != year:
                    new_m = copy.deepcopy(metadata)
                    new_m["year"] = candidate
                    yield new_m, 1

        def status_variations():
            status_value = metadata.get("status")
            if status_value is None:
                return
            # Accept either an id or a name; only vary known statuses.
            current_status = self.status_ids_to_names.get(status_value, None) or status_value
            if current_status not in self.status_names_to_ids:
                return
            for status in self.status_ids_to_names.values():
                if status != current_status:
                    new_m = copy.deepcopy(metadata)
                    new_m["status"] = status
                    yield new_m, 2

        def tags_variations():
            # Replace/add/remove some tags.
            tags_value = metadata.get("tags")
            if not tags_value:
                return
            current_tags = self._validate_tags(tags_value)
            if len(current_tags) <= 0:
                return
            # Replace one random existing tag with an unused tag.
            for tag in self.tag_ids_to_names.values():
                if tag not in current_tags:
                    new_m = copy.deepcopy(metadata)
                    new_m["tags"][np.random.randint(0, len(new_m["tags"]))] = tag
                    yield new_m, 3
            # Insert an unused tag at a random position.
            for tag in self.tag_ids_to_names.values():
                if tag not in current_tags:
                    new_m = copy.deepcopy(metadata)
                    new_m["tags"].insert(np.random.randint(0, len(new_m["tags"]) + 1), tag)
                    yield new_m, 3
            if len(current_tags) <= 1:
                return
            # Remove one tag at a time (only when more than one would remain).
            for tag in current_tags:
                new_m = copy.deepcopy(metadata)
                new_tags = [t for t in current_tags if t != tag]
                new_m["tags"] = new_tags
                yield new_m, 3

        def mapper_variations():
            mapper_value = metadata.get("mapper")
            if mapper_value is None:
                return
            current_mapper = self.mapper_names_to_ids.get(mapper_value, None) or mapper_value
            mapper_candidates = list(self.mapper_ids_to_names.keys())
            if current_mapper in self.mapper_ids_to_names:
                mapper_candidates.remove(current_mapper)
            # Randomly sample mappers to avoid too many variations.
            np.random.shuffle(mapper_candidates)
            for mapper in mapper_candidates:
                new_m = copy.deepcopy(metadata)
                new_m["mapper"] = mapper
                yield new_m, 4

        def padding_variations():
            while True:
                yield CM3PMetadata(), -1

        # Round-robin over the per-field generators until they are exhausted or
        # enough variations have been produced.
        current_num_variations = 0
        workers = [
            year_variations(),
            status_variations(),
            tags_variations(),
            mapper_variations(),
        ]
        padding_iterable = padding_variations()

        index = 0
        while current_num_variations < num_variations and len(workers) > 0:
            try:
                index = index % len(workers)
                item = next(workers[index])
                index += 1
                current_num_variations += 1
                yield item
            except StopIteration:
                # `index` was not advanced before the exception, so it still
                # points at the exhausted generator.
                del workers[index]

        # Pad with empty metadata until the requested count is reached.
        while current_num_variations < num_variations:
            current_num_variations += 1
            yield next(padding_iterable)

    @property
    def vocab_size(self):
        # Base vocabulary plus runtime-added tokens (the UNK tokens above live
        # in _added_tokens_encoder, not in the generated vocab).
        return len(self.vocab) + len(self._added_tokens_encoder)

    def get_vocab(self):
        return self.vocab | self._added_tokens_encoder

    def _convert_token_to_id(self, token):
        # NOTE(review): no unk_token is configured for this tokenizer, so
        # self.unk_token is None and unknown tokens resolve to None — confirm
        # unknown tokens can never reach this lookup.
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        return self.ids_to_tokens.get(index, self.unk_token)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        """Write the base vocabulary to ``<save_directory>/<prefix>vocab.json``."""
        if not save_directory:
            raise ValueError("The save_directory must be specified.")

        vocab_file = f"{save_directory}/{filename_prefix or ''}vocab.json"
        with open(vocab_file, 'w', encoding='utf-8') as f:
            json.dump(self.vocab, f, ensure_ascii=False)

        return (vocab_file,)
804
+
805
# Register both tokenizers so `AutoTokenizer.from_pretrained` can resolve them
# from their respective config classes.
AutoTokenizer.register(CM3PBeatmapConfig, CM3PBeatmapTokenizer)
AutoTokenizer.register(CM3PMetadataConfig, CM3PMetadataTokenizer)

__all__ = ["CM3PBeatmapTokenizer", "CM3PMetadataTokenizer", "CM3PMetadata", "merge_metadata_dicts"]