OliBomby commited on
Commit
8488397
·
verified ·
1 Parent(s): 26cc694

fix metadata tokenizer id to name mappings having string keys

Browse files
Files changed (1) hide show
  1. tokenization_cm3p.py +808 -808
tokenization_cm3p.py CHANGED
@@ -1,808 +1,808 @@
1
- import copy
2
- import json
3
- from typing import Optional, Union, TypedDict
4
-
5
- import numpy as np
6
- from transformers import PreTrainedTokenizer, BatchEncoding, AutoTokenizer
7
- from transformers.tokenization_utils_base import TruncationStrategy
8
- from transformers.utils import PaddingStrategy
9
-
10
- from .configuration_cm3p import CM3PBeatmapConfig, CM3PMetadataConfig
11
- from .parsing_cm3p import Group, EventType, EVENT_TYPES_WITH_NEW_COMBO
12
-
13
-
14
- class CM3PBeatmapTokenizer(PreTrainedTokenizer):
15
- model_input_names: list[str] = ["input_ids", "attention_mask"]
16
- vocab_files_names: dict[str, str] = {"vocab_file": "vocab.json"}
17
-
18
- def __init__(
19
- self,
20
- vocab_file: Optional[str] = None,
21
- min_time: int = 0,
22
- max_time: int = 30000,
23
- time_step: int = 10,
24
- max_distance: int = 640,
25
- distance_step: int = 4,
26
- position_range: tuple[int, int, int, int] = (-256, 768, -256, 640),
27
- position_step: int = 4,
28
- position_split_axes: bool = True,
29
- add_cls_token: bool = False,
30
- separate_new_combo_token: bool = True,
31
- **kwargs,
32
- ):
33
- self.min_time = min_time
34
- self.max_time = max_time
35
- self.time_step = time_step
36
- self.max_distance = max_distance
37
- self.distance_step = distance_step
38
- self.position_range = position_range
39
- self.position_step = position_step
40
- self.position_split_axes = position_split_axes
41
- self.add_cls_token = add_cls_token
42
- self.separate_new_combo_token = separate_new_combo_token
43
-
44
- self.audio_bos_token = "[AUDIO_BOS]"
45
- self.audio_eos_token = "[AUDIO_EOS]"
46
- self.audio_token = "[AUDIO]"
47
-
48
- if vocab_file is None:
49
- self.vocab = self._build_vocab_from_config()
50
- else:
51
- with open(vocab_file, 'r', encoding='utf-8') as f:
52
- self.vocab = json.load(f)
53
-
54
- self.ids_to_tokens = {i: t for t, i in self.vocab.items()}
55
- super().__init__(
56
- bos_token=kwargs.pop("bos_token", "[BOS]"),
57
- eos_token=kwargs.pop("eos_token", "[EOS]"),
58
- unk_token=kwargs.pop("unk_token", "[UNK]"),
59
- sep_token=kwargs.pop("sep_token", "[SEP]"),
60
- pad_token=kwargs.pop("pad_token", "[PAD]"),
61
- cls_token=kwargs.pop("cls_token", "[CLS]"),
62
- mask_token=kwargs.pop("mask_token", "[MASK]"),
63
- additional_special_tokens=kwargs.pop("additional_special_tokens", [
64
- self.audio_bos_token,
65
- self.audio_eos_token,
66
- self.audio_token,
67
- ]),
68
- min_time=min_time,
69
- max_time=max_time,
70
- time_step=time_step,
71
- max_distance=max_distance,
72
- distance_step=distance_step,
73
- position_range=position_range,
74
- position_step=position_step,
75
- position_split_axes=position_split_axes,
76
- add_cls_token=add_cls_token,
77
- separate_new_combo_token=separate_new_combo_token,
78
- **kwargs
79
- )
80
-
81
- def _build_vocab_from_config(self):
82
- vocab = []
83
-
84
- for event_type in EventType:
85
- vocab.append(f"[{event_type.value.upper()}]")
86
-
87
- if not self.separate_new_combo_token:
88
- for event_type in EVENT_TYPES_WITH_NEW_COMBO:
89
- vocab.append(f"[{event_type.value.upper()}_NEW_COMBO]")
90
-
91
- for time in np.arange(self.min_time, self.max_time + 1e-5, self.time_step):
92
- vocab.append(f"[TIME_SHIFT_{int(time)}]")
93
-
94
- for snapping in range(0, 17):
95
- vocab.append(f"[SNAPPING_{snapping}]")
96
-
97
- for distance in range(0, self.max_distance + 1):
98
- vocab.append(f"[DISTANCE_{distance}]")
99
-
100
- if self.position_split_axes:
101
- for x in np.arange(self.position_range[0], self.position_range[1] + 1e-5, self.position_step):
102
- vocab.append(f"[POS_X_{int(x)}]")
103
- for y in np.arange(self.position_range[2], self.position_range[3] + 1e-5, self.position_step):
104
- vocab.append(f"[POS_Y_{int(y)}]")
105
- else:
106
- for x in np.arange(self.position_range[0], self.position_range[1] + 1e-5, self.position_step):
107
- for y in np.arange(self.position_range[2], self.position_range[3] + 1e-5, self.position_step):
108
- vocab.append(f"[POS_{int(x)}_{int(y)}]")
109
-
110
- for mania_column in range(1, 19):
111
- vocab.append(f"[MANIA_COLUMN_{mania_column}]")
112
-
113
- for scroll_speed in np.arange(0.0, 10.0 + 1e-5, 0.01):
114
- vocab.append(f"[SCROLL_SPEED_{scroll_speed:.2f}]")
115
-
116
- if self.separate_new_combo_token:
117
- vocab.append("[NEW_COMBO]")
118
-
119
- for hitsound in range(8):
120
- for sampleset in range(1, 4):
121
- for additions in range(1, 4):
122
- vocab.append(f"[HITSOUND_{(hitsound << 1)}_{sampleset}_{additions}]")
123
-
124
- for volume in range(101):
125
- vocab.append(f"[VOLUME_{volume}]")
126
-
127
- return {token: idx for idx, token in enumerate(vocab)}
128
-
129
- def _tokenize_time_shift(self, time: int):
130
- time = np.clip(time, self.min_time, self.max_time)
131
- time = round(time / self.time_step) * self.time_step
132
- return f"[TIME_SHIFT_{int(time)}]"
133
-
134
- def _tokenize_distance(self, distance: int):
135
- distance = np.clip(distance, 0, self.max_distance)
136
- distance = round(distance / self.distance_step) * self.distance_step
137
- return f"[DISTANCE_{distance}]"
138
-
139
- def _tokenize_position(self, pos_x: int, pos_y: int):
140
- pos_x = np.clip(pos_x, self.position_range[0], self.position_range[1])
141
- pos_y = np.clip(pos_y, self.position_range[2], self.position_range[3])
142
- pos_x = round(pos_x / self.position_step) * self.position_step
143
- pos_y = round(pos_y / self.position_step) * self.position_step
144
-
145
- if self.position_split_axes:
146
- yield f"[POS_X_{int(pos_x)}]"
147
- yield f"[POS_Y_{int(pos_y)}]"
148
- else:
149
- yield f"[POS_{int(pos_x)}_{int(pos_y)}]"
150
-
151
- def _tokenize_mania_column(self, mania_column: int):
152
- mania_column = np.clip(mania_column, 1, 18)
153
- return f"[MANIA_COLUMN_{mania_column}]"
154
-
155
- def _tokenize_scroll_speed(self, scroll_speed: float):
156
- scroll_speed = np.clip(scroll_speed, 0.0, 10.0)
157
- scroll_speed = round(scroll_speed / 0.01) * 0.01
158
- return f"[SCROLL_SPEED_{scroll_speed:.2f}]"
159
-
160
- def _tokenize_hitsound(self, hitsound: int, sampleset: int, addition: int):
161
- hitsound = np.clip(hitsound >> 1, 0, 7) << 1
162
- sampleset = np.clip(sampleset, 1, 3)
163
- addition = np.clip(addition, 1, 3)
164
- return f"[HITSOUND_{hitsound}_{sampleset}_{addition}]"
165
-
166
- def _tokenize_groups(
167
- self,
168
- groups: list[Group],
169
- window_start_ms: Optional[int] = None,
170
- **_
171
- ):
172
- window_start_ms = window_start_ms or 0
173
- tokens = []
174
- if self.add_cls_token:
175
- tokens.append(self.cls_token)
176
- tokens.append(self.bos_token)
177
-
178
- for group in groups:
179
- if group.new_combo and not self.separate_new_combo_token and group.event_type in EVENT_TYPES_WITH_NEW_COMBO:
180
- tokens.append(f"[{group.event_type.value.upper()}_NEW_COMBO]")
181
- else:
182
- tokens.append(f"[{group.event_type.value.upper()}]")
183
- if group.has_time:
184
- tokens.append(self._tokenize_time_shift(group.time - window_start_ms))
185
- if group.snapping is not None:
186
- tokens.append(f"[SNAPPING_{group.snapping}]")
187
- if group.distance is not None:
188
- tokens.append(self._tokenize_distance(group.distance))
189
- if group.x is not None and group.y is not None:
190
- tokens.extend(self._tokenize_position(group.x, group.y))
191
- if group.mania_column is not None:
192
- tokens.append(self._tokenize_mania_column(group.mania_column))
193
- if group.new_combo and self.separate_new_combo_token:
194
- tokens.append("[NEW_COMBO]")
195
- if group.scroll_speed is not None:
196
- tokens.append(self._tokenize_scroll_speed(group.scroll_speed))
197
- for h, s, a, v, in zip(
198
- group.hitsounds,
199
- group.samplesets,
200
- group.additions,
201
- group.volumes,
202
- ):
203
- tokens.append(self._tokenize_hitsound(h, s, a))
204
- tokens.append(f"[VOLUME_{v}]")
205
-
206
- tokens.append(self.eos_token)
207
- return tokens
208
-
209
- def _encode_single(
210
- self,
211
- groups: Optional[Union[list[Group]]] = None,
212
- window_start_ms: Optional[int] = None,
213
- num_audio_tokens: Optional[int] = None,
214
- ):
215
- token_strings = self._tokenize_groups(groups, window_start_ms=window_start_ms)
216
- token_ids = self.convert_tokens_to_ids(token_strings)
217
-
218
- if num_audio_tokens is not None and num_audio_tokens > 0:
219
- audio_tokens = [self.audio_bos_token] + [self.audio_token] * num_audio_tokens + [self.audio_eos_token]
220
- token_ids = self.convert_tokens_to_ids(audio_tokens) + token_ids
221
-
222
- return token_ids
223
-
224
- def __call__(
225
- self,
226
- groups: Optional[Union[list[Group], list[list[Group]]]] = None,
227
- window_start_ms: Optional[Union[int, list[int]]] = None,
228
- num_audio_tokens: Optional[Union[int, list[int]]] = None,
229
- padding: PaddingStrategy = PaddingStrategy.LONGEST,
230
- truncation: TruncationStrategy = TruncationStrategy.LONGEST_FIRST,
231
- **kwargs
232
- ) -> BatchEncoding:
233
- if len(groups) == 0:
234
- raise ValueError("Input groups list is empty.")
235
-
236
- if isinstance(groups, list) and all(isinstance(g, Group) for g in groups):
237
- token_ids = self._encode_single(
238
- groups=groups,
239
- window_start_ms=window_start_ms,
240
- num_audio_tokens=num_audio_tokens,
241
- )
242
- encoding = self.prepare_for_model(
243
- token_ids,
244
- padding=padding,
245
- truncation=truncation,
246
- **kwargs,
247
- )
248
- elif isinstance(groups, list):
249
- if num_audio_tokens is None:
250
- num_audio_tokens = [None] * len(groups)
251
-
252
- if window_start_ms is None:
253
- window_start_ms = [None] * len(groups)
254
-
255
- if len(groups) != len(num_audio_tokens):
256
- raise ValueError("Number of num_audio_tokens inputs must match the number of sequences.")
257
-
258
- if len(window_start_ms) != len(groups):
259
- raise ValueError("Number of window start times must match the number of sequences.")
260
-
261
- all_token_ids = []
262
- for g, w, a in zip(groups, window_start_ms, num_audio_tokens):
263
- token_ids = self._encode_single(
264
- groups=g,
265
- window_start_ms=w,
266
- num_audio_tokens=a,
267
- )
268
- all_token_ids.append((token_ids, None))
269
-
270
- encoding = self._batch_prepare_for_model(
271
- all_token_ids,
272
- padding_strategy=PaddingStrategy(padding),
273
- truncation_strategy=TruncationStrategy(truncation),
274
- **kwargs,
275
- )
276
- else:
277
- raise ValueError("Input must be a list of Group objects or a single Group object.")
278
-
279
- return encoding
280
-
281
- @property
282
- def vocab_size(self):
283
- return len(self.vocab) + len(self._added_tokens_encoder)
284
-
285
- def get_vocab(self):
286
- return self.vocab | self._added_tokens_encoder
287
-
288
- def _convert_token_to_id(self, token):
289
- return self.vocab.get(token, self.vocab.get(self.unk_token))
290
-
291
- def _convert_id_to_token(self, index):
292
- return self.ids_to_tokens.get(index, self.unk_token)
293
-
294
- def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
295
- if not save_directory:
296
- raise ValueError("The save_directory must be specified.")
297
-
298
- vocab_file = f"{save_directory}/{filename_prefix or ''}vocab.json"
299
- with open(vocab_file, 'w', encoding='utf-8') as f:
300
- json.dump(self.vocab, f, ensure_ascii=False)
301
-
302
- return (vocab_file,)
303
-
304
-
305
- class CM3PMetadata(TypedDict, total=False):
306
- """
307
- Metadata fields for a beatmap.
308
-
309
- difficulty: Star rating, unitless (osu! difficulty)
310
- year: Year of beatmap creation (YYYY)
311
- mode: Game mode ID or name (e.g., "osu", "mania")
312
- mapper: Beatmap creator's ID or username
313
- cs: Circle size (osu!std), unitless
314
- hitsounded: Whether the beatmap is hitsounded (True/False)
315
- song_length: Song length in seconds
316
- song_position: Relative position in song [0.0-1.0], unitless
317
- global_sv: Global scroll velocity (osu!mania), multiplier
318
- mania_keycount: Number of keys in osu!mania [1-18]
319
- hold_note_ratio: Ratio of hold notes [0.0-1.0], unitless
320
- scroll_speed_ratio: Ratio of scroll speed changes [0.0-1.0], unitless
321
- tags: List of beatmap tag IDs or names
322
- """
323
- difficulty: float # Star rating, unitless (osu! difficulty)
324
- year: int # Year of beatmap creation (YYYY)
325
- mode: Union[int, str] # Game mode ID or name (e.g., "osu", "mania")
326
- status: Union[int, str] # Beatmap status (e.g., "ranked", "approved", "loved", "pending", "graveyard")
327
- mapper: Union[int, str] # Beatmap creator's ID or username
328
- cs: float # Circle size (osu!std), unitless
329
- hitsounded: bool # Whether the beatmap is hitsounded (True/False)
330
- song_length: float # Song length in seconds
331
- song_position: float # Relative position in song [0.0-1.0], unitless
332
- global_sv: float # Global slider velocity (osu!standard/catch), multiplier
333
- mania_keycount: int # Number of keys in osu!mania [1-18]
334
- hold_note_ratio: float # Ratio of hold notes [0.0-1.0], unitless
335
- scroll_speed_ratio: float # Ratio of scroll speed changes [0.0-1.0], unitless
336
- tags: list[Union[int, str]] # List of beatmap tag IDs or names
337
-
338
-
339
- def merge_metadata_dicts(m1, m2):
340
- if m1 is None:
341
- return m2
342
- if m2 is None:
343
- return m1
344
- merged = {}
345
- for key in CM3PMetadata.__annotations__.keys():
346
- v1 = m1.get(key, None)
347
- v2 = m2.get(key, None)
348
- merged[key] = v2 if v1 is None else v1
349
- return CM3PMetadata(**merged)
350
-
351
-
352
- class CM3PMetadataTokenizer(PreTrainedTokenizer):
353
- model_input_names: list[str] = ["input_ids", "attention_mask"]
354
- vocab_files_names: dict[str, str] = {"vocab_file": "vocab.json"}
355
-
356
- def __init__(
357
- self,
358
- vocab_file: Optional[str] = None,
359
- modes: Optional[dict[int, str]] = None,
360
- statuses: Optional[dict[int, str]] = None,
361
- mappers: Optional[dict[int, str]] = None,
362
- tags: Optional[dict[int, dict]] = None,
363
- min_difficculty: float = 0.0,
364
- max_difficulty: float = 14.0,
365
- difficulty_step: float = 0.1,
366
- min_year: int = 2000,
367
- max_year: int = 2023,
368
- max_song_length: int = 600,
369
- song_length_step: int = 10,
370
- song_position_step: float = 0.01,
371
- global_sv_step: float = 0.01,
372
- hold_note_ratio_step: float = 0.1,
373
- scroll_speed_ratio_step: float = 0.1,
374
- add_cls_token: bool = False,
375
- **kwargs,
376
- ):
377
- self.min_difficulty = min_difficculty
378
- self.max_difficulty = max_difficulty
379
- self.difficulty_step = difficulty_step
380
- self.min_year = min_year
381
- self.max_year = max_year
382
- self.max_song_length = max_song_length
383
- self.song_length_step = song_length_step
384
- self.song_position_step = song_position_step
385
- self.global_sv_step = global_sv_step
386
- self.hold_note_ratio_step = hold_note_ratio_step
387
- self.scroll_speed_ratio_step = scroll_speed_ratio_step
388
- self.add_cls_token = add_cls_token
389
-
390
- self.difficulty_unk_token = "[DIFFICULTY_UNK]"
391
- self.year_unk_token = "[YEAR_UNK]"
392
- self.mode_unk_token = "[MODE_UNK]"
393
- self.status_unk_token = "[STATUS_UNK]"
394
- self.mapper_unk_token = "[MAPPER_UNK]"
395
- self.cs_unk_token = "[CS_UNK]"
396
- self.hitsounded_unk_token = "[HITSOUNDED_UNK]"
397
- self.song_length_unk_token = "[SONG_LENGTH_UNK]"
398
- self.song_position_unk_token = "[SONG_POSITION_UNK]"
399
- self.global_sv_unk_token = "[GLOBAL_SV_UNK]"
400
- self.mania_keycount_unk_token = "[MANIA_KEYCOUNT_UNK]"
401
- self.hold_note_ratio_unk_token = "[HOLD_NOTE_RATIO_UNK]"
402
- self.scroll_speed_ratio_unk_token = "[SCROLL_SPEED_RATIO_UNK]"
403
- self.tag_unk_token = "[TAG_UNK]"
404
-
405
- self.modes = modes or {}
406
- self.statuses = statuses or {}
407
- self.mappers = mappers or {}
408
- self.tags = tags or {}
409
- self.mode_names_to_ids = {v: k for k, v in self.modes.items()}
410
- self.mode_ids_to_names = self.modes
411
- self.status_names_to_ids = {v: k for k, v in self.statuses.items()}
412
- self.status_ids_to_names = self.statuses
413
- self.mapper_names_to_ids = {v: k for k, v in self.mappers.items()}
414
- self.mapper_ids_to_names = self.mappers
415
- self.tag_names_to_ids = {v['name']: k for k, v in self.tags.items()}
416
- self.tag_ids_to_names = {k: v['name'] for k, v in self.tags.items()}
417
-
418
- if vocab_file is None:
419
- self.vocab = self._build_vocab_from_config()
420
- else:
421
- with open(vocab_file, 'r', encoding='utf-8') as f:
422
- self.vocab = json.load(f)
423
-
424
- self.ids_to_tokens = {i: t for t, i in self.vocab.items()}
425
-
426
- super().__init__(
427
- bos_token=kwargs.pop("bos_token", "[BOS]"),
428
- eos_token=kwargs.pop("eos_token", "[EOS]"),
429
- pad_token=kwargs.pop("pad_token", "[PAD]"),
430
- cls_token=kwargs.pop("cls_token", "[CLS]"),
431
- additional_special_tokens=kwargs.pop("additional_special_tokens", [
432
- self.difficulty_unk_token,
433
- self.year_unk_token,
434
- self.mode_unk_token,
435
- self.status_unk_token,
436
- self.mapper_unk_token,
437
- self.cs_unk_token,
438
- self.hitsounded_unk_token,
439
- self.song_length_unk_token,
440
- self.song_position_unk_token,
441
- self.global_sv_unk_token,
442
- self.mania_keycount_unk_token,
443
- self.hold_note_ratio_unk_token,
444
- self.scroll_speed_ratio_unk_token,
445
- self.tag_unk_token,
446
- ]),
447
- modes=modes,
448
- statuses=statuses,
449
- mappers=mappers,
450
- tags=tags,
451
- min_difficculty=min_difficculty,
452
- max_difficulty=max_difficulty,
453
- difficulty_step=difficulty_step,
454
- min_year=min_year,
455
- max_year=max_year,
456
- max_song_length=max_song_length,
457
- song_length_step=song_length_step,
458
- song_position_step=song_position_step,
459
- global_sv_step=global_sv_step,
460
- hold_note_ratio_step=hold_note_ratio_step,
461
- scroll_speed_ratio_step=scroll_speed_ratio_step,
462
- add_cls_token=add_cls_token,
463
- **kwargs
464
- )
465
-
466
- def _build_vocab_from_config(self):
467
- vocab = []
468
-
469
- for difficulty in np.arange(self.min_difficulty, self.max_difficulty + 1e-5, self.difficulty_step):
470
- vocab.append(f"[DIFFICULTY_{difficulty:.1f}]")
471
-
472
- for year in range(self.min_year, self.max_year + 1):
473
- vocab.append(f"[YEAR_{year}]")
474
-
475
- for mode in self.mode_ids_to_names.values():
476
- vocab.append(f"[MODE_{str(mode)}]")
477
-
478
- for status in self.status_ids_to_names.values():
479
- vocab.append(f"[STATUS_{str(status)}]")
480
-
481
- for mapper in self.mapper_ids_to_names.keys():
482
- vocab.append(f"[MAPPER_{str(mapper)}]")
483
-
484
- for cs in np.arange(0.0, 10.0 + 1e-5, 0.1):
485
- vocab.append(f"[CS_{cs:.1f}]")
486
-
487
- for hitsounded in [True, False]:
488
- vocab.append(f"[HITSOUNDED_{str(hitsounded).upper()}]")
489
-
490
- for song_length in np.arange(0, self.max_song_length + 1e-5, self.song_length_step):
491
- vocab.append(f"[SONG_LENGTH_{int(song_length)}]")
492
-
493
- for song_position in np.arange(0.0, 1.0 + 1e-5, self.song_position_step):
494
- vocab.append(f"[SONG_POSITION_{song_position:.2f}]")
495
-
496
- for global_sv in np.arange(0.4, 3.6 + 1e-5, self.global_sv_step):
497
- vocab.append(f"[GLOBAL_SV_{global_sv:.2f}]")
498
-
499
- for mania_keycount in range(1, 19):
500
- vocab.append(f"[MANIA_KEYCOUNT_{mania_keycount}]")
501
-
502
- for hold_note_ratio in np.arange(0.0, 1.0 + 1e-5, self.hold_note_ratio_step):
503
- vocab.append(f"[HOLD_NOTE_RATIO_{hold_note_ratio:.1f}]")
504
-
505
- for scroll_speed_ratio in np.arange(0.0, 1.0 + 1e-5, self.scroll_speed_ratio_step):
506
- vocab.append(f"[SCROLL_SPEED_RATIO_{scroll_speed_ratio:.1f}]")
507
-
508
- for tag in self.tag_ids_to_names.values():
509
- vocab.append(f"[TAG_{tag}]")
510
-
511
- return {token: idx for idx, token in enumerate(vocab)}
512
-
513
- def _tokenize_difficulty(self, metadata: CM3PMetadata):
514
- difficulty = metadata.get('difficulty', None)
515
- if difficulty is None:
516
- return self.difficulty_unk_token
517
- difficulty = np.clip(difficulty, self.min_difficulty, self.max_difficulty)
518
- difficulty = round(difficulty / self.difficulty_step) * self.difficulty_step
519
- return f"[DIFFICULTY_{difficulty:.1f}]"
520
-
521
- def _tokenize_year(self, metadata: CM3PMetadata):
522
- year = metadata.get('year', None)
523
- if year is None:
524
- return self.year_unk_token
525
- year = np.clip(year, self.min_year, self.max_year)
526
- return f"[YEAR_{year}]"
527
-
528
- def _tokenize_mode(self, metadata: CM3PMetadata):
529
- mode_str = metadata.get('mode', None)
530
- if isinstance(mode_str, int):
531
- mode_str = self.mode_ids_to_names.get(mode_str, None)
532
- if mode_str is None or mode_str not in self.mode_names_to_ids:
533
- return self.mode_unk_token
534
- return f"[MODE_{str(mode_str)}]"
535
-
536
- def _tokenize_status(self, metadata: CM3PMetadata):
537
- status_str = metadata.get('status', None)
538
- if isinstance(status_str, int):
539
- status_str = self.status_ids_to_names.get(status_str, None)
540
- if status_str is None or status_str not in self.status_names_to_ids:
541
- return self.status_unk_token
542
- return f"[STATUS_{str(status_str)}]"
543
-
544
- def _tokenize_mapper(self, metadata: CM3PMetadata):
545
- mapper_id = metadata.get('mapper', None)
546
- if isinstance(mapper_id, str):
547
- mapper_id = self.mapper_names_to_ids.get(mapper_id, None)
548
- if mapper_id is None or mapper_id not in self.mapper_ids_to_names:
549
- return self.mapper_unk_token
550
- return f"[MAPPER_{str(mapper_id)}]"
551
-
552
- def _tokenize_cs(self, metadata: CM3PMetadata):
553
- cs = metadata.get('cs', None)
554
- if cs is None:
555
- return self.cs_unk_token
556
- cs = np.clip(cs, 0.0, 10.0)
557
- cs = round(cs / 0.1) * 0.1
558
- return f"[CS_{cs:.1f}]"
559
-
560
- def _tokenize_hitsounded(self, metadata: CM3PMetadata):
561
- hitsounded = metadata.get('hitsounded', None)
562
- if hitsounded is None:
563
- return self.hitsounded_unk_token
564
- return f"[HITSOUNDED_{str(hitsounded).upper()}]"
565
-
566
- def _tokenize_song_length(self, metadata: CM3PMetadata):
567
- song_length = metadata.get('song_length', None)
568
- if song_length is None:
569
- return self.song_length_unk_token
570
- song_length = np.clip(song_length, 0, self.max_song_length)
571
- song_length = round(song_length / self.song_length_step) * self.song_length_step
572
- return f"[SONG_LENGTH_{int(song_length)}]"
573
-
574
- def _tokenize_song_position(self, metadata: CM3PMetadata):
575
- song_position = metadata.get('song_position', None)
576
- if song_position is None:
577
- return self.song_position_unk_token
578
- song_position = np.clip(song_position, 0.0, 1.0)
579
- song_position = round(song_position / self.song_position_step) * self.song_position_step
580
- return f"[SONG_POSITION_{song_position:.2f}]"
581
-
582
- def _tokenize_global_sv(self, metadata: CM3PMetadata):
583
- global_sv = metadata.get('global_sv', None)
584
- if global_sv is None:
585
- return self.global_sv_unk_token
586
- global_sv = np.clip(global_sv, 0.4, 3.6)
587
- global_sv = round(global_sv / self.global_sv_step) * self.global_sv_step
588
- return f"[GLOBAL_SV_{global_sv:.2f}]"
589
-
590
- def _tokenize_mania_keycount(self, metadata: CM3PMetadata):
591
- mania_keycount = metadata.get('mania_keycount', None)
592
- if mania_keycount is None:
593
- return self.mania_keycount_unk_token
594
- mania_keycount = int(mania_keycount)
595
- mania_keycount = np.clip(mania_keycount, 1, 18)
596
- return f"[MANIA_KEYCOUNT_{mania_keycount}]"
597
-
598
- def _tokenize_hold_note_ratio(self, metadata: CM3PMetadata):
599
- hold_note_ratio = metadata.get('hold_note_ratio', None)
600
- if hold_note_ratio is None:
601
- return self.hold_note_ratio_unk_token
602
- hold_note_ratio = np.clip(hold_note_ratio, 0.0, 1.0)
603
- hold_note_ratio = round(hold_note_ratio / self.hold_note_ratio_step) * self.hold_note_ratio_step
604
- return f"[HOLD_NOTE_RATIO_{hold_note_ratio:.1f}]"
605
-
606
- def _tokenize_scroll_speed_ratio(self, metadata: CM3PMetadata):
607
- scroll_speed_ratio = metadata.get('scroll_speed_ratio', None)
608
- if scroll_speed_ratio is None:
609
- return self.scroll_speed_ratio_unk_token
610
- scroll_speed_ratio = np.clip(scroll_speed_ratio, 0.0, 1.0)
611
- scroll_speed_ratio = round(scroll_speed_ratio / self.scroll_speed_ratio_step) * self.scroll_speed_ratio_step
612
- return f"[SCROLL_SPEED_RATIO_{scroll_speed_ratio:.1f}]"
613
-
614
- def _validate_tags(self, tags):
615
- if tags is None:
616
- return None
617
- new_tags = []
618
- for tag in tags:
619
- if isinstance(tag, str) and tag in self.tag_names_to_ids:
620
- new_tags.append(tag)
621
- elif tag in self.tag_ids_to_names:
622
- new_tags.append(self.tag_ids_to_names[tag])
623
- return new_tags
624
-
625
- def _tokenize_tags(self, metadata: CM3PMetadata):
626
- tags = metadata.get('tags', None)
627
- valid_tags = self._validate_tags(tags)
628
- if not valid_tags:
629
- return [self.tag_unk_token]
630
- return [f"[TAG_{tag}]" for tag in valid_tags]
631
-
632
- def _tokenize_metadata(self, metadata: CM3PMetadata):
633
- tokens = []
634
- if self.add_cls_token:
635
- tokens.append(self.cls_token)
636
- tokens.extend([
637
- self.bos_token,
638
- self._tokenize_difficulty(metadata),
639
- self._tokenize_year(metadata),
640
- self._tokenize_mode(metadata),
641
- self._tokenize_status(metadata),
642
- self._tokenize_mapper(metadata),
643
- self._tokenize_cs(metadata),
644
- self._tokenize_hitsounded(metadata),
645
- self._tokenize_song_length(metadata),
646
- self._tokenize_song_position(metadata),
647
- self._tokenize_global_sv(metadata),
648
- self._tokenize_mania_keycount(metadata),
649
- self._tokenize_hold_note_ratio(metadata),
650
- self._tokenize_scroll_speed_ratio(metadata),
651
- ])
652
- tokens.extend(self._tokenize_tags(metadata))
653
- tokens.append(self.eos_token)
654
- return tokens
655
-
656
- def __call__(
657
- self,
658
- metadata: Optional[Union[CM3PMetadata, list[CM3PMetadata]]] = None,
659
- padding: PaddingStrategy = PaddingStrategy.LONGEST,
660
- truncation: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
661
- max_length: Optional[int] = None,
662
- return_tensors: Optional[str] = "pt",
663
- **kwargs
664
- ) -> BatchEncoding:
665
- if isinstance(metadata, dict):
666
- token_strings = self._tokenize_metadata(metadata)
667
- token_ids = self.convert_tokens_to_ids(token_strings)
668
- return self.prepare_for_model(
669
- token_ids,
670
- padding=padding,
671
- truncation=truncation,
672
- max_length=max_length,
673
- return_tensors=return_tensors,
674
- **kwargs,
675
- )
676
- elif isinstance(metadata, list):
677
- all_token_ids = []
678
- for m in metadata:
679
- token_strings = self._tokenize_metadata(m)
680
- token_ids = self.convert_tokens_to_ids(token_strings)
681
- all_token_ids.append((token_ids, None))
682
-
683
- return self._batch_prepare_for_model(
684
- all_token_ids,
685
- padding_strategy=PaddingStrategy(padding),
686
- truncation_strategy=TruncationStrategy(truncation),
687
- max_length=max_length,
688
- return_tensors=return_tensors,
689
- )
690
-
691
- def metadata_variations(self, metadata: CM3PMetadata, num_variations: int = 1000) -> tuple[CM3PMetadata, int]:
692
- def year_variations():
693
- min_year = max(2007, self.min_year)
694
- if metadata["year"] is None or (min_year > metadata["year"] or metadata["year"] > self.max_year):
695
- return
696
- for year in range(min_year, self.max_year + 1):
697
- if year != metadata["year"]:
698
- new_m = copy.deepcopy(metadata)
699
- new_m["year"] = year
700
- yield new_m, 1
701
-
702
- def status_variations():
703
- if metadata["status"] is None:
704
- return
705
- current_status = self.status_ids_to_names.get(metadata["status"], None) or metadata["status"]
706
- if current_status not in self.status_names_to_ids:
707
- return
708
- for status in self.status_ids_to_names.values():
709
- if status != current_status:
710
- new_m = copy.deepcopy(metadata)
711
- new_m["status"] = status
712
- yield new_m, 2
713
-
714
- def tags_variations():
715
- # Replace/add/remove some tags
716
- if metadata["tags"] is None or len(metadata["tags"]) <= 0:
717
- return
718
- current_tags = self._validate_tags(metadata["tags"])
719
- if len(current_tags) <= 0:
720
- return
721
- for tag in self.tag_ids_to_names.values():
722
- if tag not in current_tags:
723
- new_m = copy.deepcopy(metadata)
724
- new_m["tags"][np.random.randint(0, len(new_m["tags"]))] = tag
725
- yield new_m, 3
726
- for tag in self.tag_ids_to_names.values():
727
- if tag not in current_tags:
728
- new_m = copy.deepcopy(metadata)
729
- new_m["tags"].insert(np.random.randint(0, len(new_m["tags"]) + 1), tag)
730
- yield new_m, 3
731
- if len(current_tags) <= 1:
732
- return
733
- for tag in current_tags:
734
- new_m = copy.deepcopy(metadata)
735
- new_tags = [t for t in current_tags if t != tag]
736
- new_m["tags"] = new_tags
737
- yield new_m, 3
738
-
739
- def mapper_variations():
740
- if metadata['mapper'] is None:
741
- return
742
- current_mapper = self.mapper_names_to_ids.get(metadata["mapper"], None) or metadata["mapper"]
743
- mapper_variations = list(self.mapper_ids_to_names.keys())
744
- if current_mapper in self.mapper_ids_to_names:
745
- mapper_variations.remove(current_mapper)
746
- # Randomly sample mappers to avoid too many variations
747
- np.random.shuffle(mapper_variations)
748
- for mapper in mapper_variations:
749
- new_m = copy.deepcopy(metadata)
750
- new_m["mapper"] = mapper
751
- yield new_m, 4
752
-
753
- def padding_variations():
754
- while True:
755
- yield CM3PMetadata(), -1
756
-
757
- # Add variations with one field changed at a time
758
- current_num_variations = 0
759
- workers = [
760
- year_variations(),
761
- status_variations(),
762
- tags_variations(),
763
- mapper_variations(),
764
- ]
765
- padding_iterable = padding_variations()
766
-
767
- index = 0
768
- while current_num_variations < num_variations and len(workers) > 0:
769
- try:
770
- index = index % len(workers)
771
- item = workers[index].__next__()
772
- index += 1
773
- current_num_variations += 1
774
- yield item
775
- except StopIteration:
776
- workers.remove(workers[index])
777
-
778
- while current_num_variations < num_variations:
779
- current_num_variations += 1
780
- yield padding_iterable.__next__()
781
-
782
- @property
783
- def vocab_size(self):
784
- return len(self.vocab) + len(self._added_tokens_encoder)
785
-
786
- def get_vocab(self):
787
- return self.vocab | self._added_tokens_encoder
788
-
789
- def _convert_token_to_id(self, token):
790
- return self.vocab.get(token, self.vocab.get(self.unk_token))
791
-
792
- def _convert_id_to_token(self, index):
793
- return self.ids_to_tokens.get(index, self.unk_token)
794
-
795
- def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
796
- if not save_directory:
797
- raise ValueError("The save_directory must be specified.")
798
-
799
- vocab_file = f"{save_directory}/{filename_prefix or ''}vocab.json"
800
- with open(vocab_file, 'w', encoding='utf-8') as f:
801
- json.dump(self.vocab, f, ensure_ascii=False)
802
-
803
- return (vocab_file,)
804
-
805
- AutoTokenizer.register(CM3PBeatmapConfig, CM3PBeatmapTokenizer)
806
- AutoTokenizer.register(CM3PMetadataConfig, CM3PMetadataTokenizer)
807
-
808
- __all__ = ["CM3PBeatmapTokenizer", "CM3PMetadataTokenizer", "CM3PMetadata", "merge_metadata_dicts"]
 
1
+ import copy
2
+ import json
3
+ from typing import Optional, Union, TypedDict
4
+
5
+ import numpy as np
6
+ from transformers import PreTrainedTokenizer, BatchEncoding, AutoTokenizer
7
+ from transformers.tokenization_utils_base import TruncationStrategy
8
+ from transformers.utils import PaddingStrategy
9
+
10
+ from .configuration_cm3p import CM3PBeatmapConfig, CM3PMetadataConfig
11
+ from .parsing_cm3p import Group, EventType, EVENT_TYPES_WITH_NEW_COMBO
12
+
13
+
14
+ class CM3PBeatmapTokenizer(PreTrainedTokenizer):
15
+ model_input_names: list[str] = ["input_ids", "attention_mask"]
16
+ vocab_files_names: dict[str, str] = {"vocab_file": "vocab.json"}
17
+
18
+ def __init__(
19
+ self,
20
+ vocab_file: Optional[str] = None,
21
+ min_time: int = 0,
22
+ max_time: int = 30000,
23
+ time_step: int = 10,
24
+ max_distance: int = 640,
25
+ distance_step: int = 4,
26
+ position_range: tuple[int, int, int, int] = (-256, 768, -256, 640),
27
+ position_step: int = 4,
28
+ position_split_axes: bool = True,
29
+ add_cls_token: bool = False,
30
+ separate_new_combo_token: bool = True,
31
+ **kwargs,
32
+ ):
33
+ self.min_time = min_time
34
+ self.max_time = max_time
35
+ self.time_step = time_step
36
+ self.max_distance = max_distance
37
+ self.distance_step = distance_step
38
+ self.position_range = position_range
39
+ self.position_step = position_step
40
+ self.position_split_axes = position_split_axes
41
+ self.add_cls_token = add_cls_token
42
+ self.separate_new_combo_token = separate_new_combo_token
43
+
44
+ self.audio_bos_token = "[AUDIO_BOS]"
45
+ self.audio_eos_token = "[AUDIO_EOS]"
46
+ self.audio_token = "[AUDIO]"
47
+
48
+ if vocab_file is None:
49
+ self.vocab = self._build_vocab_from_config()
50
+ else:
51
+ with open(vocab_file, 'r', encoding='utf-8') as f:
52
+ self.vocab = json.load(f)
53
+
54
+ self.ids_to_tokens = {i: t for t, i in self.vocab.items()}
55
+ super().__init__(
56
+ bos_token=kwargs.pop("bos_token", "[BOS]"),
57
+ eos_token=kwargs.pop("eos_token", "[EOS]"),
58
+ unk_token=kwargs.pop("unk_token", "[UNK]"),
59
+ sep_token=kwargs.pop("sep_token", "[SEP]"),
60
+ pad_token=kwargs.pop("pad_token", "[PAD]"),
61
+ cls_token=kwargs.pop("cls_token", "[CLS]"),
62
+ mask_token=kwargs.pop("mask_token", "[MASK]"),
63
+ additional_special_tokens=kwargs.pop("additional_special_tokens", [
64
+ self.audio_bos_token,
65
+ self.audio_eos_token,
66
+ self.audio_token,
67
+ ]),
68
+ min_time=min_time,
69
+ max_time=max_time,
70
+ time_step=time_step,
71
+ max_distance=max_distance,
72
+ distance_step=distance_step,
73
+ position_range=position_range,
74
+ position_step=position_step,
75
+ position_split_axes=position_split_axes,
76
+ add_cls_token=add_cls_token,
77
+ separate_new_combo_token=separate_new_combo_token,
78
+ **kwargs
79
+ )
80
+
81
+ def _build_vocab_from_config(self):
82
+ vocab = []
83
+
84
+ for event_type in EventType:
85
+ vocab.append(f"[{event_type.value.upper()}]")
86
+
87
+ if not self.separate_new_combo_token:
88
+ for event_type in EVENT_TYPES_WITH_NEW_COMBO:
89
+ vocab.append(f"[{event_type.value.upper()}_NEW_COMBO]")
90
+
91
+ for time in np.arange(self.min_time, self.max_time + 1e-5, self.time_step):
92
+ vocab.append(f"[TIME_SHIFT_{int(time)}]")
93
+
94
+ for snapping in range(0, 17):
95
+ vocab.append(f"[SNAPPING_{snapping}]")
96
+
97
+ for distance in range(0, self.max_distance + 1):
98
+ vocab.append(f"[DISTANCE_{distance}]")
99
+
100
+ if self.position_split_axes:
101
+ for x in np.arange(self.position_range[0], self.position_range[1] + 1e-5, self.position_step):
102
+ vocab.append(f"[POS_X_{int(x)}]")
103
+ for y in np.arange(self.position_range[2], self.position_range[3] + 1e-5, self.position_step):
104
+ vocab.append(f"[POS_Y_{int(y)}]")
105
+ else:
106
+ for x in np.arange(self.position_range[0], self.position_range[1] + 1e-5, self.position_step):
107
+ for y in np.arange(self.position_range[2], self.position_range[3] + 1e-5, self.position_step):
108
+ vocab.append(f"[POS_{int(x)}_{int(y)}]")
109
+
110
+ for mania_column in range(1, 19):
111
+ vocab.append(f"[MANIA_COLUMN_{mania_column}]")
112
+
113
+ for scroll_speed in np.arange(0.0, 10.0 + 1e-5, 0.01):
114
+ vocab.append(f"[SCROLL_SPEED_{scroll_speed:.2f}]")
115
+
116
+ if self.separate_new_combo_token:
117
+ vocab.append("[NEW_COMBO]")
118
+
119
+ for hitsound in range(8):
120
+ for sampleset in range(1, 4):
121
+ for additions in range(1, 4):
122
+ vocab.append(f"[HITSOUND_{(hitsound << 1)}_{sampleset}_{additions}]")
123
+
124
+ for volume in range(101):
125
+ vocab.append(f"[VOLUME_{volume}]")
126
+
127
+ return {token: idx for idx, token in enumerate(vocab)}
128
+
129
+ def _tokenize_time_shift(self, time: int):
130
+ time = np.clip(time, self.min_time, self.max_time)
131
+ time = round(time / self.time_step) * self.time_step
132
+ return f"[TIME_SHIFT_{int(time)}]"
133
+
134
+ def _tokenize_distance(self, distance: int):
135
+ distance = np.clip(distance, 0, self.max_distance)
136
+ distance = round(distance / self.distance_step) * self.distance_step
137
+ return f"[DISTANCE_{distance}]"
138
+
139
+ def _tokenize_position(self, pos_x: int, pos_y: int):
140
+ pos_x = np.clip(pos_x, self.position_range[0], self.position_range[1])
141
+ pos_y = np.clip(pos_y, self.position_range[2], self.position_range[3])
142
+ pos_x = round(pos_x / self.position_step) * self.position_step
143
+ pos_y = round(pos_y / self.position_step) * self.position_step
144
+
145
+ if self.position_split_axes:
146
+ yield f"[POS_X_{int(pos_x)}]"
147
+ yield f"[POS_Y_{int(pos_y)}]"
148
+ else:
149
+ yield f"[POS_{int(pos_x)}_{int(pos_y)}]"
150
+
151
+ def _tokenize_mania_column(self, mania_column: int):
152
+ mania_column = np.clip(mania_column, 1, 18)
153
+ return f"[MANIA_COLUMN_{mania_column}]"
154
+
155
+ def _tokenize_scroll_speed(self, scroll_speed: float):
156
+ scroll_speed = np.clip(scroll_speed, 0.0, 10.0)
157
+ scroll_speed = round(scroll_speed / 0.01) * 0.01
158
+ return f"[SCROLL_SPEED_{scroll_speed:.2f}]"
159
+
160
+ def _tokenize_hitsound(self, hitsound: int, sampleset: int, addition: int):
161
+ hitsound = np.clip(hitsound >> 1, 0, 7) << 1
162
+ sampleset = np.clip(sampleset, 1, 3)
163
+ addition = np.clip(addition, 1, 3)
164
+ return f"[HITSOUND_{hitsound}_{sampleset}_{addition}]"
165
+
166
+ def _tokenize_groups(
167
+ self,
168
+ groups: list[Group],
169
+ window_start_ms: Optional[int] = None,
170
+ **_
171
+ ):
172
+ window_start_ms = window_start_ms or 0
173
+ tokens = []
174
+ if self.add_cls_token:
175
+ tokens.append(self.cls_token)
176
+ tokens.append(self.bos_token)
177
+
178
+ for group in groups:
179
+ if group.new_combo and not self.separate_new_combo_token and group.event_type in EVENT_TYPES_WITH_NEW_COMBO:
180
+ tokens.append(f"[{group.event_type.value.upper()}_NEW_COMBO]")
181
+ else:
182
+ tokens.append(f"[{group.event_type.value.upper()}]")
183
+ if group.has_time:
184
+ tokens.append(self._tokenize_time_shift(group.time - window_start_ms))
185
+ if group.snapping is not None:
186
+ tokens.append(f"[SNAPPING_{group.snapping}]")
187
+ if group.distance is not None:
188
+ tokens.append(self._tokenize_distance(group.distance))
189
+ if group.x is not None and group.y is not None:
190
+ tokens.extend(self._tokenize_position(group.x, group.y))
191
+ if group.mania_column is not None:
192
+ tokens.append(self._tokenize_mania_column(group.mania_column))
193
+ if group.new_combo and self.separate_new_combo_token:
194
+ tokens.append("[NEW_COMBO]")
195
+ if group.scroll_speed is not None:
196
+ tokens.append(self._tokenize_scroll_speed(group.scroll_speed))
197
+ for h, s, a, v, in zip(
198
+ group.hitsounds,
199
+ group.samplesets,
200
+ group.additions,
201
+ group.volumes,
202
+ ):
203
+ tokens.append(self._tokenize_hitsound(h, s, a))
204
+ tokens.append(f"[VOLUME_{v}]")
205
+
206
+ tokens.append(self.eos_token)
207
+ return tokens
208
+
209
+ def _encode_single(
210
+ self,
211
+ groups: Optional[Union[list[Group]]] = None,
212
+ window_start_ms: Optional[int] = None,
213
+ num_audio_tokens: Optional[int] = None,
214
+ ):
215
+ token_strings = self._tokenize_groups(groups, window_start_ms=window_start_ms)
216
+ token_ids = self.convert_tokens_to_ids(token_strings)
217
+
218
+ if num_audio_tokens is not None and num_audio_tokens > 0:
219
+ audio_tokens = [self.audio_bos_token] + [self.audio_token] * num_audio_tokens + [self.audio_eos_token]
220
+ token_ids = self.convert_tokens_to_ids(audio_tokens) + token_ids
221
+
222
+ return token_ids
223
+
224
+ def __call__(
225
+ self,
226
+ groups: Optional[Union[list[Group], list[list[Group]]]] = None,
227
+ window_start_ms: Optional[Union[int, list[int]]] = None,
228
+ num_audio_tokens: Optional[Union[int, list[int]]] = None,
229
+ padding: PaddingStrategy = PaddingStrategy.LONGEST,
230
+ truncation: TruncationStrategy = TruncationStrategy.LONGEST_FIRST,
231
+ **kwargs
232
+ ) -> BatchEncoding:
233
+ if len(groups) == 0:
234
+ raise ValueError("Input groups list is empty.")
235
+
236
+ if isinstance(groups, list) and all(isinstance(g, Group) for g in groups):
237
+ token_ids = self._encode_single(
238
+ groups=groups,
239
+ window_start_ms=window_start_ms,
240
+ num_audio_tokens=num_audio_tokens,
241
+ )
242
+ encoding = self.prepare_for_model(
243
+ token_ids,
244
+ padding=padding,
245
+ truncation=truncation,
246
+ **kwargs,
247
+ )
248
+ elif isinstance(groups, list):
249
+ if num_audio_tokens is None:
250
+ num_audio_tokens = [None] * len(groups)
251
+
252
+ if window_start_ms is None:
253
+ window_start_ms = [None] * len(groups)
254
+
255
+ if len(groups) != len(num_audio_tokens):
256
+ raise ValueError("Number of num_audio_tokens inputs must match the number of sequences.")
257
+
258
+ if len(window_start_ms) != len(groups):
259
+ raise ValueError("Number of window start times must match the number of sequences.")
260
+
261
+ all_token_ids = []
262
+ for g, w, a in zip(groups, window_start_ms, num_audio_tokens):
263
+ token_ids = self._encode_single(
264
+ groups=g,
265
+ window_start_ms=w,
266
+ num_audio_tokens=a,
267
+ )
268
+ all_token_ids.append((token_ids, None))
269
+
270
+ encoding = self._batch_prepare_for_model(
271
+ all_token_ids,
272
+ padding_strategy=PaddingStrategy(padding),
273
+ truncation_strategy=TruncationStrategy(truncation),
274
+ **kwargs,
275
+ )
276
+ else:
277
+ raise ValueError("Input must be a list of Group objects or a single Group object.")
278
+
279
+ return encoding
280
+
281
+ @property
282
+ def vocab_size(self):
283
+ return len(self.vocab) + len(self._added_tokens_encoder)
284
+
285
+ def get_vocab(self):
286
+ return self.vocab | self._added_tokens_encoder
287
+
288
+ def _convert_token_to_id(self, token):
289
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
290
+
291
+ def _convert_id_to_token(self, index):
292
+ return self.ids_to_tokens.get(index, self.unk_token)
293
+
294
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
295
+ if not save_directory:
296
+ raise ValueError("The save_directory must be specified.")
297
+
298
+ vocab_file = f"{save_directory}/{filename_prefix or ''}vocab.json"
299
+ with open(vocab_file, 'w', encoding='utf-8') as f:
300
+ json.dump(self.vocab, f, ensure_ascii=False)
301
+
302
+ return (vocab_file,)
303
+
304
+
305
+ class CM3PMetadata(TypedDict, total=False):
306
+ """
307
+ Metadata fields for a beatmap.
308
+
309
+ difficulty: Star rating, unitless (osu! difficulty)
310
+ year: Year of beatmap creation (YYYY)
311
+ mode: Game mode ID or name (e.g., "osu", "mania")
312
+ mapper: Beatmap creator's ID or username
313
+ cs: Circle size (osu!std), unitless
314
+ hitsounded: Whether the beatmap is hitsounded (True/False)
315
+ song_length: Song length in seconds
316
+ song_position: Relative position in song [0.0-1.0], unitless
317
+ global_sv: Global scroll velocity (osu!mania), multiplier
318
+ mania_keycount: Number of keys in osu!mania [1-18]
319
+ hold_note_ratio: Ratio of hold notes [0.0-1.0], unitless
320
+ scroll_speed_ratio: Ratio of scroll speed changes [0.0-1.0], unitless
321
+ tags: List of beatmap tag IDs or names
322
+ """
323
+ difficulty: float # Star rating, unitless (osu! difficulty)
324
+ year: int # Year of beatmap creation (YYYY)
325
+ mode: Union[int, str] # Game mode ID or name (e.g., "osu", "mania")
326
+ status: Union[int, str] # Beatmap status (e.g., "ranked", "approved", "loved", "pending", "graveyard")
327
+ mapper: Union[int, str] # Beatmap creator's ID or username
328
+ cs: float # Circle size (osu!std), unitless
329
+ hitsounded: bool # Whether the beatmap is hitsounded (True/False)
330
+ song_length: float # Song length in seconds
331
+ song_position: float # Relative position in song [0.0-1.0], unitless
332
+ global_sv: float # Global slider velocity (osu!standard/catch), multiplier
333
+ mania_keycount: int # Number of keys in osu!mania [1-18]
334
+ hold_note_ratio: float # Ratio of hold notes [0.0-1.0], unitless
335
+ scroll_speed_ratio: float # Ratio of scroll speed changes [0.0-1.0], unitless
336
+ tags: list[Union[int, str]] # List of beatmap tag IDs or names
337
+
338
+
339
+ def merge_metadata_dicts(m1, m2):
340
+ if m1 is None:
341
+ return m2
342
+ if m2 is None:
343
+ return m1
344
+ merged = {}
345
+ for key in CM3PMetadata.__annotations__.keys():
346
+ v1 = m1.get(key, None)
347
+ v2 = m2.get(key, None)
348
+ merged[key] = v2 if v1 is None else v1
349
+ return CM3PMetadata(**merged)
350
+
351
+
352
+ class CM3PMetadataTokenizer(PreTrainedTokenizer):
353
+ model_input_names: list[str] = ["input_ids", "attention_mask"]
354
+ vocab_files_names: dict[str, str] = {"vocab_file": "vocab.json"}
355
+
356
+ def __init__(
357
+ self,
358
+ vocab_file: Optional[str] = None,
359
+ modes: Optional[dict[int, str]] = None,
360
+ statuses: Optional[dict[int, str]] = None,
361
+ mappers: Optional[dict[int, str]] = None,
362
+ tags: Optional[dict[int, dict]] = None,
363
+ min_difficculty: float = 0.0,
364
+ max_difficulty: float = 14.0,
365
+ difficulty_step: float = 0.1,
366
+ min_year: int = 2000,
367
+ max_year: int = 2023,
368
+ max_song_length: int = 600,
369
+ song_length_step: int = 10,
370
+ song_position_step: float = 0.01,
371
+ global_sv_step: float = 0.01,
372
+ hold_note_ratio_step: float = 0.1,
373
+ scroll_speed_ratio_step: float = 0.1,
374
+ add_cls_token: bool = False,
375
+ **kwargs,
376
+ ):
377
+ self.min_difficulty = min_difficculty
378
+ self.max_difficulty = max_difficulty
379
+ self.difficulty_step = difficulty_step
380
+ self.min_year = min_year
381
+ self.max_year = max_year
382
+ self.max_song_length = max_song_length
383
+ self.song_length_step = song_length_step
384
+ self.song_position_step = song_position_step
385
+ self.global_sv_step = global_sv_step
386
+ self.hold_note_ratio_step = hold_note_ratio_step
387
+ self.scroll_speed_ratio_step = scroll_speed_ratio_step
388
+ self.add_cls_token = add_cls_token
389
+
390
+ self.difficulty_unk_token = "[DIFFICULTY_UNK]"
391
+ self.year_unk_token = "[YEAR_UNK]"
392
+ self.mode_unk_token = "[MODE_UNK]"
393
+ self.status_unk_token = "[STATUS_UNK]"
394
+ self.mapper_unk_token = "[MAPPER_UNK]"
395
+ self.cs_unk_token = "[CS_UNK]"
396
+ self.hitsounded_unk_token = "[HITSOUNDED_UNK]"
397
+ self.song_length_unk_token = "[SONG_LENGTH_UNK]"
398
+ self.song_position_unk_token = "[SONG_POSITION_UNK]"
399
+ self.global_sv_unk_token = "[GLOBAL_SV_UNK]"
400
+ self.mania_keycount_unk_token = "[MANIA_KEYCOUNT_UNK]"
401
+ self.hold_note_ratio_unk_token = "[HOLD_NOTE_RATIO_UNK]"
402
+ self.scroll_speed_ratio_unk_token = "[SCROLL_SPEED_RATIO_UNK]"
403
+ self.tag_unk_token = "[TAG_UNK]"
404
+
405
+ self.modes = modes or {}
406
+ self.statuses = statuses or {}
407
+ self.mappers = mappers or {}
408
+ self.tags = tags or {}
409
+ self.mode_names_to_ids = {v: k for k, v in self.modes.items()}
410
+ self.mode_ids_to_names = {int(k): v for k, v in self.modes.items()}
411
+ self.status_names_to_ids = {v: k for k, v in self.statuses.items()}
412
+ self.status_ids_to_names = {int(k): v for k, v in self.statuses.items()}
413
+ self.mapper_names_to_ids = {v: k for k, v in self.mappers.items()}
414
+ self.mapper_ids_to_names = {int(k): v for k, v in self.mappers.items()}
415
+ self.tag_names_to_ids = {v['name']: k for k, v in self.tags.items()}
416
+ self.tag_ids_to_names = {int(k): v['name'] for k, v in self.tags.items()}
417
+
418
+ if vocab_file is None:
419
+ self.vocab = self._build_vocab_from_config()
420
+ else:
421
+ with open(vocab_file, 'r', encoding='utf-8') as f:
422
+ self.vocab = json.load(f)
423
+
424
+ self.ids_to_tokens = {i: t for t, i in self.vocab.items()}
425
+
426
+ super().__init__(
427
+ bos_token=kwargs.pop("bos_token", "[BOS]"),
428
+ eos_token=kwargs.pop("eos_token", "[EOS]"),
429
+ pad_token=kwargs.pop("pad_token", "[PAD]"),
430
+ cls_token=kwargs.pop("cls_token", "[CLS]"),
431
+ additional_special_tokens=kwargs.pop("additional_special_tokens", [
432
+ self.difficulty_unk_token,
433
+ self.year_unk_token,
434
+ self.mode_unk_token,
435
+ self.status_unk_token,
436
+ self.mapper_unk_token,
437
+ self.cs_unk_token,
438
+ self.hitsounded_unk_token,
439
+ self.song_length_unk_token,
440
+ self.song_position_unk_token,
441
+ self.global_sv_unk_token,
442
+ self.mania_keycount_unk_token,
443
+ self.hold_note_ratio_unk_token,
444
+ self.scroll_speed_ratio_unk_token,
445
+ self.tag_unk_token,
446
+ ]),
447
+ modes=modes,
448
+ statuses=statuses,
449
+ mappers=mappers,
450
+ tags=tags,
451
+ min_difficculty=min_difficculty,
452
+ max_difficulty=max_difficulty,
453
+ difficulty_step=difficulty_step,
454
+ min_year=min_year,
455
+ max_year=max_year,
456
+ max_song_length=max_song_length,
457
+ song_length_step=song_length_step,
458
+ song_position_step=song_position_step,
459
+ global_sv_step=global_sv_step,
460
+ hold_note_ratio_step=hold_note_ratio_step,
461
+ scroll_speed_ratio_step=scroll_speed_ratio_step,
462
+ add_cls_token=add_cls_token,
463
+ **kwargs
464
+ )
465
+
466
+ def _build_vocab_from_config(self):
467
+ vocab = []
468
+
469
+ for difficulty in np.arange(self.min_difficulty, self.max_difficulty + 1e-5, self.difficulty_step):
470
+ vocab.append(f"[DIFFICULTY_{difficulty:.1f}]")
471
+
472
+ for year in range(self.min_year, self.max_year + 1):
473
+ vocab.append(f"[YEAR_{year}]")
474
+
475
+ for mode in self.mode_ids_to_names.values():
476
+ vocab.append(f"[MODE_{str(mode)}]")
477
+
478
+ for status in self.status_ids_to_names.values():
479
+ vocab.append(f"[STATUS_{str(status)}]")
480
+
481
+ for mapper in self.mapper_ids_to_names.keys():
482
+ vocab.append(f"[MAPPER_{str(mapper)}]")
483
+
484
+ for cs in np.arange(0.0, 10.0 + 1e-5, 0.1):
485
+ vocab.append(f"[CS_{cs:.1f}]")
486
+
487
+ for hitsounded in [True, False]:
488
+ vocab.append(f"[HITSOUNDED_{str(hitsounded).upper()}]")
489
+
490
+ for song_length in np.arange(0, self.max_song_length + 1e-5, self.song_length_step):
491
+ vocab.append(f"[SONG_LENGTH_{int(song_length)}]")
492
+
493
+ for song_position in np.arange(0.0, 1.0 + 1e-5, self.song_position_step):
494
+ vocab.append(f"[SONG_POSITION_{song_position:.2f}]")
495
+
496
+ for global_sv in np.arange(0.4, 3.6 + 1e-5, self.global_sv_step):
497
+ vocab.append(f"[GLOBAL_SV_{global_sv:.2f}]")
498
+
499
+ for mania_keycount in range(1, 19):
500
+ vocab.append(f"[MANIA_KEYCOUNT_{mania_keycount}]")
501
+
502
+ for hold_note_ratio in np.arange(0.0, 1.0 + 1e-5, self.hold_note_ratio_step):
503
+ vocab.append(f"[HOLD_NOTE_RATIO_{hold_note_ratio:.1f}]")
504
+
505
+ for scroll_speed_ratio in np.arange(0.0, 1.0 + 1e-5, self.scroll_speed_ratio_step):
506
+ vocab.append(f"[SCROLL_SPEED_RATIO_{scroll_speed_ratio:.1f}]")
507
+
508
+ for tag in self.tag_ids_to_names.values():
509
+ vocab.append(f"[TAG_{tag}]")
510
+
511
+ return {token: idx for idx, token in enumerate(vocab)}
512
+
513
+ def _tokenize_difficulty(self, metadata: CM3PMetadata):
514
+ difficulty = metadata.get('difficulty', None)
515
+ if difficulty is None:
516
+ return self.difficulty_unk_token
517
+ difficulty = np.clip(difficulty, self.min_difficulty, self.max_difficulty)
518
+ difficulty = round(difficulty / self.difficulty_step) * self.difficulty_step
519
+ return f"[DIFFICULTY_{difficulty:.1f}]"
520
+
521
+ def _tokenize_year(self, metadata: CM3PMetadata):
522
+ year = metadata.get('year', None)
523
+ if year is None:
524
+ return self.year_unk_token
525
+ year = np.clip(year, self.min_year, self.max_year)
526
+ return f"[YEAR_{year}]"
527
+
528
+ def _tokenize_mode(self, metadata: CM3PMetadata):
529
+ mode_str = metadata.get('mode', None)
530
+ if isinstance(mode_str, int):
531
+ mode_str = self.mode_ids_to_names.get(mode_str, None)
532
+ if mode_str is None or mode_str not in self.mode_names_to_ids:
533
+ return self.mode_unk_token
534
+ return f"[MODE_{str(mode_str)}]"
535
+
536
+ def _tokenize_status(self, metadata: CM3PMetadata):
537
+ status_str = metadata.get('status', None)
538
+ if isinstance(status_str, int):
539
+ status_str = self.status_ids_to_names.get(status_str, None)
540
+ if status_str is None or status_str not in self.status_names_to_ids:
541
+ return self.status_unk_token
542
+ return f"[STATUS_{str(status_str)}]"
543
+
544
+ def _tokenize_mapper(self, metadata: CM3PMetadata):
545
+ mapper_id = metadata.get('mapper', None)
546
+ if isinstance(mapper_id, str):
547
+ mapper_id = self.mapper_names_to_ids.get(mapper_id, None)
548
+ if mapper_id is None or mapper_id not in self.mapper_ids_to_names:
549
+ return self.mapper_unk_token
550
+ return f"[MAPPER_{str(mapper_id)}]"
551
+
552
+ def _tokenize_cs(self, metadata: CM3PMetadata):
553
+ cs = metadata.get('cs', None)
554
+ if cs is None:
555
+ return self.cs_unk_token
556
+ cs = np.clip(cs, 0.0, 10.0)
557
+ cs = round(cs / 0.1) * 0.1
558
+ return f"[CS_{cs:.1f}]"
559
+
560
+ def _tokenize_hitsounded(self, metadata: CM3PMetadata):
561
+ hitsounded = metadata.get('hitsounded', None)
562
+ if hitsounded is None:
563
+ return self.hitsounded_unk_token
564
+ return f"[HITSOUNDED_{str(hitsounded).upper()}]"
565
+
566
+ def _tokenize_song_length(self, metadata: CM3PMetadata):
567
+ song_length = metadata.get('song_length', None)
568
+ if song_length is None:
569
+ return self.song_length_unk_token
570
+ song_length = np.clip(song_length, 0, self.max_song_length)
571
+ song_length = round(song_length / self.song_length_step) * self.song_length_step
572
+ return f"[SONG_LENGTH_{int(song_length)}]"
573
+
574
+ def _tokenize_song_position(self, metadata: CM3PMetadata):
575
+ song_position = metadata.get('song_position', None)
576
+ if song_position is None:
577
+ return self.song_position_unk_token
578
+ song_position = np.clip(song_position, 0.0, 1.0)
579
+ song_position = round(song_position / self.song_position_step) * self.song_position_step
580
+ return f"[SONG_POSITION_{song_position:.2f}]"
581
+
582
+ def _tokenize_global_sv(self, metadata: CM3PMetadata):
583
+ global_sv = metadata.get('global_sv', None)
584
+ if global_sv is None:
585
+ return self.global_sv_unk_token
586
+ global_sv = np.clip(global_sv, 0.4, 3.6)
587
+ global_sv = round(global_sv / self.global_sv_step) * self.global_sv_step
588
+ return f"[GLOBAL_SV_{global_sv:.2f}]"
589
+
590
+ def _tokenize_mania_keycount(self, metadata: CM3PMetadata):
591
+ mania_keycount = metadata.get('mania_keycount', None)
592
+ if mania_keycount is None:
593
+ return self.mania_keycount_unk_token
594
+ mania_keycount = int(mania_keycount)
595
+ mania_keycount = np.clip(mania_keycount, 1, 18)
596
+ return f"[MANIA_KEYCOUNT_{mania_keycount}]"
597
+
598
+ def _tokenize_hold_note_ratio(self, metadata: CM3PMetadata):
599
+ hold_note_ratio = metadata.get('hold_note_ratio', None)
600
+ if hold_note_ratio is None:
601
+ return self.hold_note_ratio_unk_token
602
+ hold_note_ratio = np.clip(hold_note_ratio, 0.0, 1.0)
603
+ hold_note_ratio = round(hold_note_ratio / self.hold_note_ratio_step) * self.hold_note_ratio_step
604
+ return f"[HOLD_NOTE_RATIO_{hold_note_ratio:.1f}]"
605
+
606
+ def _tokenize_scroll_speed_ratio(self, metadata: CM3PMetadata):
607
+ scroll_speed_ratio = metadata.get('scroll_speed_ratio', None)
608
+ if scroll_speed_ratio is None:
609
+ return self.scroll_speed_ratio_unk_token
610
+ scroll_speed_ratio = np.clip(scroll_speed_ratio, 0.0, 1.0)
611
+ scroll_speed_ratio = round(scroll_speed_ratio / self.scroll_speed_ratio_step) * self.scroll_speed_ratio_step
612
+ return f"[SCROLL_SPEED_RATIO_{scroll_speed_ratio:.1f}]"
613
+
614
+ def _validate_tags(self, tags):
615
+ if tags is None:
616
+ return None
617
+ new_tags = []
618
+ for tag in tags:
619
+ if isinstance(tag, str) and tag in self.tag_names_to_ids:
620
+ new_tags.append(tag)
621
+ elif tag in self.tag_ids_to_names:
622
+ new_tags.append(self.tag_ids_to_names[tag])
623
+ return new_tags
624
+
625
+ def _tokenize_tags(self, metadata: CM3PMetadata):
626
+ tags = metadata.get('tags', None)
627
+ valid_tags = self._validate_tags(tags)
628
+ if not valid_tags:
629
+ return [self.tag_unk_token]
630
+ return [f"[TAG_{tag}]" for tag in valid_tags]
631
+
632
+ def _tokenize_metadata(self, metadata: CM3PMetadata):
633
+ tokens = []
634
+ if self.add_cls_token:
635
+ tokens.append(self.cls_token)
636
+ tokens.extend([
637
+ self.bos_token,
638
+ self._tokenize_difficulty(metadata),
639
+ self._tokenize_year(metadata),
640
+ self._tokenize_mode(metadata),
641
+ self._tokenize_status(metadata),
642
+ self._tokenize_mapper(metadata),
643
+ self._tokenize_cs(metadata),
644
+ self._tokenize_hitsounded(metadata),
645
+ self._tokenize_song_length(metadata),
646
+ self._tokenize_song_position(metadata),
647
+ self._tokenize_global_sv(metadata),
648
+ self._tokenize_mania_keycount(metadata),
649
+ self._tokenize_hold_note_ratio(metadata),
650
+ self._tokenize_scroll_speed_ratio(metadata),
651
+ ])
652
+ tokens.extend(self._tokenize_tags(metadata))
653
+ tokens.append(self.eos_token)
654
+ return tokens
655
+
656
+ def __call__(
657
+ self,
658
+ metadata: Optional[Union[CM3PMetadata, list[CM3PMetadata]]] = None,
659
+ padding: PaddingStrategy = PaddingStrategy.LONGEST,
660
+ truncation: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
661
+ max_length: Optional[int] = None,
662
+ return_tensors: Optional[str] = "pt",
663
+ **kwargs
664
+ ) -> BatchEncoding:
665
+ if isinstance(metadata, dict):
666
+ token_strings = self._tokenize_metadata(metadata)
667
+ token_ids = self.convert_tokens_to_ids(token_strings)
668
+ return self.prepare_for_model(
669
+ token_ids,
670
+ padding=padding,
671
+ truncation=truncation,
672
+ max_length=max_length,
673
+ return_tensors=return_tensors,
674
+ **kwargs,
675
+ )
676
+ elif isinstance(metadata, list):
677
+ all_token_ids = []
678
+ for m in metadata:
679
+ token_strings = self._tokenize_metadata(m)
680
+ token_ids = self.convert_tokens_to_ids(token_strings)
681
+ all_token_ids.append((token_ids, None))
682
+
683
+ return self._batch_prepare_for_model(
684
+ all_token_ids,
685
+ padding_strategy=PaddingStrategy(padding),
686
+ truncation_strategy=TruncationStrategy(truncation),
687
+ max_length=max_length,
688
+ return_tensors=return_tensors,
689
+ )
690
+
691
+ def metadata_variations(self, metadata: CM3PMetadata, num_variations: int = 1000) -> tuple[CM3PMetadata, int]:
692
+ def year_variations():
693
+ min_year = max(2007, self.min_year)
694
+ if metadata["year"] is None or (min_year > metadata["year"] or metadata["year"] > self.max_year):
695
+ return
696
+ for year in range(min_year, self.max_year + 1):
697
+ if year != metadata["year"]:
698
+ new_m = copy.deepcopy(metadata)
699
+ new_m["year"] = year
700
+ yield new_m, 1
701
+
702
+ def status_variations():
703
+ if metadata["status"] is None:
704
+ return
705
+ current_status = self.status_ids_to_names.get(metadata["status"], None) or metadata["status"]
706
+ if current_status not in self.status_names_to_ids:
707
+ return
708
+ for status in self.status_ids_to_names.values():
709
+ if status != current_status:
710
+ new_m = copy.deepcopy(metadata)
711
+ new_m["status"] = status
712
+ yield new_m, 2
713
+
714
+ def tags_variations():
715
+ # Replace/add/remove some tags
716
+ if metadata["tags"] is None or len(metadata["tags"]) <= 0:
717
+ return
718
+ current_tags = self._validate_tags(metadata["tags"])
719
+ if len(current_tags) <= 0:
720
+ return
721
+ for tag in self.tag_ids_to_names.values():
722
+ if tag not in current_tags:
723
+ new_m = copy.deepcopy(metadata)
724
+ new_m["tags"][np.random.randint(0, len(new_m["tags"]))] = tag
725
+ yield new_m, 3
726
+ for tag in self.tag_ids_to_names.values():
727
+ if tag not in current_tags:
728
+ new_m = copy.deepcopy(metadata)
729
+ new_m["tags"].insert(np.random.randint(0, len(new_m["tags"]) + 1), tag)
730
+ yield new_m, 3
731
+ if len(current_tags) <= 1:
732
+ return
733
+ for tag in current_tags:
734
+ new_m = copy.deepcopy(metadata)
735
+ new_tags = [t for t in current_tags if t != tag]
736
+ new_m["tags"] = new_tags
737
+ yield new_m, 3
738
+
739
+ def mapper_variations():
740
+ if metadata['mapper'] is None:
741
+ return
742
+ current_mapper = self.mapper_names_to_ids.get(metadata["mapper"], None) or metadata["mapper"]
743
+ mapper_variations = list(self.mapper_ids_to_names.keys())
744
+ if current_mapper in self.mapper_ids_to_names:
745
+ mapper_variations.remove(current_mapper)
746
+ # Randomly sample mappers to avoid too many variations
747
+ np.random.shuffle(mapper_variations)
748
+ for mapper in mapper_variations:
749
+ new_m = copy.deepcopy(metadata)
750
+ new_m["mapper"] = mapper
751
+ yield new_m, 4
752
+
753
+ def padding_variations():
754
+ while True:
755
+ yield CM3PMetadata(), -1
756
+
757
+ # Add variations with one field changed at a time
758
+ current_num_variations = 0
759
+ workers = [
760
+ year_variations(),
761
+ status_variations(),
762
+ tags_variations(),
763
+ mapper_variations(),
764
+ ]
765
+ padding_iterable = padding_variations()
766
+
767
+ index = 0
768
+ while current_num_variations < num_variations and len(workers) > 0:
769
+ try:
770
+ index = index % len(workers)
771
+ item = workers[index].__next__()
772
+ index += 1
773
+ current_num_variations += 1
774
+ yield item
775
+ except StopIteration:
776
+ workers.remove(workers[index])
777
+
778
+ while current_num_variations < num_variations:
779
+ current_num_variations += 1
780
+ yield padding_iterable.__next__()
781
+
782
+ @property
783
+ def vocab_size(self):
784
+ return len(self.vocab) + len(self._added_tokens_encoder)
785
+
786
+ def get_vocab(self):
787
+ return self.vocab | self._added_tokens_encoder
788
+
789
+ def _convert_token_to_id(self, token):
790
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
791
+
792
+ def _convert_id_to_token(self, index):
793
+ return self.ids_to_tokens.get(index, self.unk_token)
794
+
795
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
796
+ if not save_directory:
797
+ raise ValueError("The save_directory must be specified.")
798
+
799
+ vocab_file = f"{save_directory}/{filename_prefix or ''}vocab.json"
800
+ with open(vocab_file, 'w', encoding='utf-8') as f:
801
+ json.dump(self.vocab, f, ensure_ascii=False)
802
+
803
+ return (vocab_file,)
804
+
805
+ AutoTokenizer.register(CM3PBeatmapConfig, CM3PBeatmapTokenizer)
806
+ AutoTokenizer.register(CM3PMetadataConfig, CM3PMetadataTokenizer)
807
+
808
+ __all__ = ["CM3PBeatmapTokenizer", "CM3PMetadataTokenizer", "CM3PMetadata", "merge_metadata_dicts"]