OliBomby committed on
Commit
4015376
·
verified ·
1 Parent(s): 8781d03

Upload processing_cm3p.py

Browse files
Files changed (1) hide show
  1. processing_cm3p.py +704 -0
processing_cm3p.py ADDED
@@ -0,0 +1,704 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import itertools
3
+ import math
4
+ import os
5
+ from os import PathLike
6
+ from pathlib import Path
7
+ from typing import Optional, Union, IO, TypedDict
8
+
9
+ import numpy as np
10
+ import soxr
11
+ from pandas import Series
12
+ from slider import Beatmap, HoldNote
13
+ from transformers import WhisperFeatureExtractor, AutoProcessor, BatchEncoding
14
+ from transformers.tokenization_utils_base import TruncationStrategy
15
+ from transformers.utils import is_torch_available, PaddingStrategy, PROCESSOR_NAME, logging
16
+
17
+ from .configuration_cm3p import CM3PConfig
18
+ from .parsing_cm3p import CM3PBeatmapParser, load_beatmap, get_song_length
19
+ from .tokenization_cm3p import CM3PBeatmapTokenizer, CM3PMetadataTokenizer, CM3PMetadata, merge_metadata_dicts
20
+
21
+ if is_torch_available():
22
+ import torch
23
+
24
+ from transformers.audio_utils import AudioInput, make_list_of_audio, load_audio
25
+ from transformers.feature_extraction_utils import BatchFeature
26
+ from transformers.processing_utils import AudioKwargs, ProcessorMixin, CommonKwargs
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
def get_hold_note_ratio(beatmap: Beatmap) -> Optional[float]:
    """Return the fraction of hit objects in *beatmap* that are hold notes.

    Args:
        beatmap: Beatmap whose unstacked hit objects are inspected.

    Returns:
        ``hold notes / all hit objects``, or ``None`` when the beatmap has no
        hit objects at all.
    """
    hit_objects = beatmap.hit_objects(stacking=False)
    total = len(hit_objects)

    if total == 0:
        return None

    hold_total = sum(1 for hit_object in hit_objects if isinstance(hit_object, HoldNote))
    return hold_total / total
42
+
43
+
44
def get_scroll_speed_ratio(beatmap: Beatmap) -> Optional[float]:
    """Return scroll speed changes per distinct hit object time.

    The result is the number of scroll speed changes found in the timing
    points divided by the number of distinct hit object times.

    Args:
        beatmap: Beatmap providing ``hit_objects(stacking=False)`` and
            ``timing_points``.

    Returns:
        The ratio described above, or ``None`` when there are no hit objects.
    """
    hit_objects = beatmap.hit_objects(stacking=False)

    if len(hit_objects) == 0:
        return None

    # Count runs of equal consecutive times (-1 is the "nothing seen yet" marker).
    distinct_times = 0
    previous_time = -1
    for hit_object in hit_objects:
        if hit_object.time != previous_time:
            distinct_times += 1
            previous_time = hit_object.time

    # -1 marks "no timing point seen yet"; the first observed speed is not a change.
    changes = 0
    current_speed = -1
    for point in beatmap.timing_points:
        if point.parent is None:
            # A timing point without a parent resets the scroll speed to 1.
            current_speed = 1
            continue

        new_speed = -100 / point.ms_per_beat
        if new_speed != current_speed and current_speed != -1:
            changes += 1
        current_speed = new_speed

    return changes / distinct_times
68
+
69
+
70
def get_hitsounded_status(beatmap: Beatmap) -> bool:
    """Tell whether any hit object of *beatmap* carries a hitsound.

    Args:
        beatmap: Beatmap whose unstacked hit objects are inspected.

    Returns:
        ``True`` as soon as one hit object has a non-zero ``hitsound`` value,
        ``False`` otherwise (including for a beatmap without hit objects).
    """
    hit_objects = beatmap.hit_objects(stacking=False)
    return any(hit_object.hitsound != 0 for hit_object in hit_objects)
76
+
77
+
78
def get_difficulty(beatmap_metadata: Series, speed: float = 1.0) -> float:
    """Look up the star rating of a beatmap at a given playback speed.

    ``beatmap_metadata["StarRating"]`` holds one difficulty value for each of
    the speeds 0.5, 0.75, 1.0, 1.25, 1.5, 1.75 and 2.0. The value for *speed*
    is linearly interpolated between the two closest entries.

    Args:
        beatmap_metadata: Mapping/row exposing a ``"StarRating"`` sequence.
        speed: Playback speed multiplier to evaluate the difficulty at.

    Returns:
        The interpolated star rating.
    """
    rated_speeds = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
    ratings_per_speed = beatmap_metadata["StarRating"]
    return np.interp(speed, rated_speeds, ratings_per_speed)
85
+
86
+
87
def get_metadata(
        beatmap_metadata: Optional[Series] = None,
        beatmap: Optional[Beatmap] = None,
        audio_samples: Optional[np.ndarray] = None,
        sampling_rate: Optional[int] = None,
        speed: float = 1.0,
        song_position: Optional[float] = None,
) -> CM3PMetadata:
    """Assemble a `CM3PMetadata` record from a metadata row and/or a parsed beatmap.

    Fields that need a source which was not provided are set to ``None``.
    When both sources are given, ``mode`` and the circle size are taken from
    the parsed beatmap rather than from the metadata row.

    Args:
        beatmap_metadata: Row with the keys ``"StarRating"``, ``"SubmittedDate"``,
            ``"ModeInt"``, ``"Cs"``, ``"Status"``, ``"UserId"`` and ``"TopTagIds"``.
        beatmap: Parsed beatmap used for the mode, circle size, hitsounds,
            slider multiplier and the mode-specific ratios.
        audio_samples: Audio waveform forwarded to `get_song_length`.
        sampling_rate: Sampling rate of *audio_samples*, forwarded to `get_song_length`.
        speed: Playback speed at which the difficulty is looked up.
        song_position: Stored unchanged in the ``song_position`` field.

    Returns:
        The populated `CM3PMetadata`.
    """
    has_row = beatmap_metadata is not None
    has_beatmap = beatmap is not None

    if has_beatmap:
        mode = beatmap.mode
        circle_size = beatmap.circle_size
    elif has_row:
        mode = beatmap_metadata["ModeInt"]
        circle_size = beatmap_metadata["Cs"]
    else:
        mode = None
        circle_size = None

    song_length = get_song_length(audio_samples, sampling_rate, beatmap)

    # Mode-dependent fields: the code treats modes 0 and 2 as the ones with a
    # circle size / global slider velocity, mode 3 as the one with a key count
    # and hold notes, and modes 1 and 3 as the ones with a scroll speed ratio.
    # The original condition `mode in [0, 2] is not None` was an accidental
    # chained comparison that happened to evaluate to `mode in [0, 2]`.
    uses_circle_size = mode in (0, 2)
    is_mania = mode == 3
    uses_scroll_speed = mode in (1, 3)

    return CM3PMetadata(
        difficulty=get_difficulty(beatmap_metadata, speed) if has_row else None,
        year=beatmap_metadata["SubmittedDate"].year if has_row else None,
        mode=mode,
        status=beatmap_metadata["Status"] if has_row else None,
        mapper=beatmap_metadata["UserId"] if has_row else None,
        cs=circle_size if uses_circle_size else None,
        hitsounded=get_hitsounded_status(beatmap) if has_beatmap else None,
        song_length=song_length,
        song_position=song_position,
        global_sv=beatmap.slider_multiplier if uses_circle_size and has_beatmap else None,
        mania_keycount=int(circle_size) if is_mania and has_beatmap else None,
        hold_note_ratio=get_hold_note_ratio(beatmap) if is_mania and has_beatmap else None,
        scroll_speed_ratio=get_scroll_speed_ratio(beatmap) if uses_scroll_speed and has_beatmap else None,
        tags=beatmap_metadata["TopTagIds"].tolist() if has_row else None,
    )
114
+
115
+
116
class CM3PTokenizerKwargs(TypedDict, total=False):
    """Optional keyword arguments accepted by the CM3P tokenizers.

    Every key is optional (``total=False``). The key names are what
    `CM3PProcessor._merge_kwargs` uses to route flat keyword arguments to the
    tokenizer modalities.
    """

    # Special tokens, padding and truncation
    add_special_tokens: Optional[bool]
    padding: Union[bool, str, PaddingStrategy]
    truncation: Union[bool, str, TruncationStrategy]
    max_length: Optional[int]
    pad_to_multiple_of: Optional[int]

    # Optional extra outputs
    return_token_type_ids: Optional[bool]
    return_attention_mask: Optional[bool]
    return_overflowing_tokens: Optional[bool]
    return_special_tokens_mask: Optional[bool]
    return_offsets_mapping: Optional[bool]
    return_length: Optional[bool]

    # Miscellaneous
    verbose: Optional[bool]
    padding_side: Optional[str]
    return_mm_token_type_ids: Optional[bool]
131
+
132
+
133
class CM3PBeatmapKwargs(CM3PTokenizerKwargs, total=False):
    """Tokenizer keyword arguments plus the sliding-window settings for beatmaps."""

    # Length of each beatmap/audio window, in seconds.
    window_length_sec: float
    # Offset between the starts of consecutive windows, in seconds.
    window_stride_sec: float
136
+
137
+
138
class CM3PAudioKwargs(AudioKwargs, total=False):
    """Audio keyword arguments extended with the CM3P audio framing settings."""

    # Width of one input-features chunk; features are reshaped to this size.
    max_source_positions: Optional[int]
    # Divisor applied to the padded sample count when counting audio tokens.
    hop_length: Optional[int]
    # Minimum padded length (in samples) when no `pad_to_multiple_of` is set.
    window_size: Optional[int]
    # Number of `hop_length`-sized steps that make up one audio token.
    audio_length_per_tok: Optional[int]
143
+
144
+
145
# noinspection PyTypedDict
class CM3PProcessorKwargs(CommonKwargs, CM3PBeatmapKwargs, CM3PTokenizerKwargs, CM3PAudioKwargs, total=False):
    """Nested keyword-argument structure used by `CM3PProcessor`.

    The four annotated attributes (``common_kwargs``, ``beatmap_kwargs``,
    ``metadata_kwargs``, ``audio_kwargs``) name the modalities; their annotated
    types list the keys each modality accepts. ``_defaults`` holds the values
    the processor falls back to when it is built without ``default_kwargs``.
    """

    _defaults = {
        "beatmap_kwargs": {
            "max_length": 8000,
            "padding": PaddingStrategy.LONGEST,
            "truncation": TruncationStrategy.LONGEST_FIRST,
            "window_length_sec": 30.0,
            "window_stride_sec": 30.0,
        },
        "metadata_kwargs": {
            "max_length": 128,
            "padding": PaddingStrategy.LONGEST,
            "truncation": TruncationStrategy.LONGEST_FIRST,
        },
        "audio_kwargs": {
            "sampling_rate": 16000,
            "padding": True,
            "truncation": False,
            # 480000 samples == 30 s at the default 16 kHz sampling rate.
            "pad_to_multiple_of": 480000,
            "max_source_positions": 3000,
            "hop_length": 160,
            "window_size": 400,
            "audio_length_per_tok": 8,
        },
        "common_kwargs": {
            "return_tensors": "pt",
        },
    }

    # One entry per modality; the annotation is what `_merge_kwargs` inspects.
    common_kwargs: CommonKwargs = {
        **CommonKwargs.__annotations__,
    }
    beatmap_kwargs: CM3PBeatmapKwargs = {
        **CM3PTokenizerKwargs.__annotations__,
    }
    metadata_kwargs: CM3PTokenizerKwargs = {
        **CM3PTokenizerKwargs.__annotations__,
    }
    audio_kwargs: CM3PAudioKwargs = {
        **CM3PAudioKwargs.__annotations__,
    }
187
+
188
+
189
class CM3PProcessor(ProcessorMixin):
    r"""
    Constructs a CM3P processor which wraps a [`WhisperFeatureExtractor`], a [`CM3PBeatmapParser`], a
    [`CM3PBeatmapTokenizer`] and a [`CM3PMetadataTokenizer`] into a single processor that offers audio feature
    extraction, beatmap parsing/tokenization and metadata tokenization.

    Args:
        audio_feature_extractor ([`WhisperFeatureExtractor`]):
            The feature extractor is a required input.
        beatmap_parser ([`CM3PBeatmapParser`]):
            The beatmap parser is a required input.
        beatmap_tokenizer ([`CM3PBeatmapTokenizer`]):
            The beatmap tokenizer is a required input.
        metadata_tokenizer ([`CM3PMetadataTokenizer`]):
            The metadata tokenizer is a required input.
        default_kwargs (`CM3PProcessorKwargs`, *optional*):
            Default keyword arguments for the processor. If not provided, the processor will use its own defaults
            (a deep copy of `CM3PProcessorKwargs._defaults`).
    """

    attributes = ["audio_feature_extractor", "beatmap_parser", "beatmap_tokenizer", "metadata_tokenizer"]
    audio_feature_extractor_class = "WhisperFeatureExtractor"
    beatmap_parser_class = "CM3PBeatmapParser"
    beatmap_tokenizer_class = "CM3PBeatmapTokenizer"
    metadata_tokenizer_class = "CM3PMetadataTokenizer"

    def __init__(
            self,
            audio_feature_extractor: WhisperFeatureExtractor,
            beatmap_parser: CM3PBeatmapParser,
            beatmap_tokenizer: CM3PBeatmapTokenizer,
            metadata_tokenizer: CM3PMetadataTokenizer,
            default_kwargs: Optional[CM3PProcessorKwargs] = None,
    ):
        self.audio_feature_extractor = audio_feature_extractor
        self.beatmap_parser = beatmap_parser
        self.beatmap_tokenizer = beatmap_tokenizer
        self.metadata_tokenizer = metadata_tokenizer
        self.audio_token = beatmap_tokenizer.audio_token

        # Deep copy so that per-instance changes never leak into the class-level defaults.
        # noinspection PyProtectedMember
        self.default_kwargs = default_kwargs or copy.deepcopy(CM3PProcessorKwargs._defaults)

        super().__init__(audio_feature_extractor, beatmap_parser, beatmap_tokenizer, metadata_tokenizer)

    def _pad_audio(
            self,
            audio_array: np.ndarray,
            window_size: int = 400,
            pad_to_multiple_of: Optional[int] = 480000,
            **_,
    ) -> np.ndarray:
        r"""Zero-pad the end of the audio array to the desired length.

        Args:
            audio_array: Audio data as a numpy array.
            window_size: Minimum length (in samples) enforced when `pad_to_multiple_of` is falsy.
            pad_to_multiple_of: If truthy, the array is padded up to the next multiple of this many samples.

        Returns:
            Padded audio array.
        """
        num_samples = audio_array.shape[-1]
        if pad_to_multiple_of:
            # An empty slice is padded to one full chunk instead of staying empty, so that
            # feature extraction always receives at least one chunk of (silent) audio.
            num_chunks = max(1, math.ceil(num_samples / pad_to_multiple_of))
            audio_array = np.pad(audio_array, (0, num_chunks * pad_to_multiple_of - num_samples))
        elif num_samples < window_size:
            # minimum length for audios is at least one spectrogram frame
            audio_array = np.pad(audio_array, (0, window_size - num_samples))

        return audio_array

    def _encode_audio(
            self,
            audio: np.ndarray,
            hop_length: int = 160,
            audio_length_per_tok: int = 8,
            **kwargs,
    ) -> tuple[np.ndarray, int]:
        """Pad *audio* and compute how many audio tokens it corresponds to.

        Args:
            audio: Audio waveform.
            hop_length: Number of samples per spectrogram step.
            audio_length_per_tok: Number of spectrogram steps per audio token.
            **kwargs: Forwarded to `_pad_audio` (`window_size`, `pad_to_multiple_of`; other keys are ignored).

        Returns:
            A tuple of the padded waveform and the number of audio tokens.
        """
        audio = self._pad_audio(audio, **kwargs)
        signal_length = audio.shape[0]

        # for spectrogram-based models, the waveform is downsampled by the hop_length when computing the log-mel
        if signal_length % hop_length != 0:
            signal_length = math.ceil(signal_length / hop_length - 1)
        else:
            signal_length = signal_length // hop_length

        num_audio_tokens = math.ceil(signal_length / audio_length_per_tok)

        return audio, num_audio_tokens

    # The return annotation is a string so that the class can still be defined when torch is not installed
    # (the module only imports torch under `is_torch_available()`).
    def _retrieve_input_features(self, audio, max_source_positions, **kwargs) -> "Union[torch.Tensor, np.ndarray]":
        """
        Handles specific logic of CM3P expected input features: audio arrays should be padded to next multiple of
        480000 (duration is a multiple of 30s), see CM3PProcessorKwargs' default audio_kwargs.
        Then mel input features are extracted and stacked along batch dimension, splitting into chunks of
        max_source_positions.

        Args:
            audio: Iterable of already padded audio arrays.
            max_source_positions: Size of the last dimension of each chunk.
            **kwargs: Forwarded to the audio feature extractor; `return_tensors` (default `"pt"`) also selects
                whether the chunks are concatenated with torch or numpy.

        Returns:
            All chunks of all audio arrays concatenated along the first dimension.
        """
        return_tensors = kwargs.get("return_tensors", "pt")
        feature_size = self.audio_feature_extractor.feature_size

        input_features_list = []
        for audio_array in audio:
            audio_inputs = self.audio_feature_extractor(audio_array, **kwargs)

            # let's split into chunks of max_source_positions, and then stack them along batch dimension
            input_features = audio_inputs["input_features"].reshape(feature_size, -1, max_source_positions)
            input_features_list.append(input_features.swapaxes(0, 1))

        if return_tensors == "pt":
            return torch.cat(input_features_list)

        return np.concatenate(input_features_list)

    def _load_audio(
            self,
            sampling_rate: int,
            audio: Union[str, list[str], Path, list[Path], AudioInput],
            audio_sampling_rate: Optional[Union[int, list[int]]] = None,
            speed: float = 1.0,
    ) -> list[np.ndarray]:
        """
        Helper method to load audio from various formats and return a list of audio buffers.

        Args:
            sampling_rate: Sampling rate every returned buffer is expressed in.
            audio: A path, a list of paths, or already decoded audio.
            audio_sampling_rate: Sampling rate(s) of already decoded audio. Ignored for path inputs. When missing,
                decoded audio is assumed to already be at `sampling_rate` (a warning is logged).
            speed: Playback speed applied to audio loaded from a path: the file is decoded at
                `sampling_rate // speed` and then treated as `sampling_rate` audio. Already decoded audio is
                not modified by this argument.

        Returns:
            One mono numpy array per input audio.

        Raises:
            ValueError: If a list of sampling rates does not have one entry per audio.
        """
        # convert Path objects to str (element-wise, so lists mixing `str` and `Path` are handled as paths too)
        if isinstance(audio, Path):
            audio = str(audio)
        elif isinstance(audio, list):
            audio = [str(el) if isinstance(el, Path) else el for el in audio]

        is_str = isinstance(audio, str)
        is_list_of_str = isinstance(audio, list) and all(isinstance(el, str) for el in audio)

        if is_str or is_list_of_str:
            load_rate = int(sampling_rate // speed)
            paths = [audio] if is_str else audio
            audio = [load_audio(path, sampling_rate=load_rate) for path in paths]
            audio_sampling_rate = sampling_rate
        elif audio_sampling_rate is None:
            # noinspection PyUnresolvedReferences
            logger.warning_once(
                f"You've provided audio without specifying the sampling rate. It will be assumed to be {sampling_rate}, which can result in silent errors."
            )
            audio_sampling_rate = sampling_rate

        audio = make_list_of_audio(audio)

        if not isinstance(audio_sampling_rate, (list, tuple)):
            audio_sampling_rate = [audio_sampling_rate] * len(audio)
        elif len(audio_sampling_rate) != len(audio):
            # `zip` below would otherwise silently drop the unmatched audio.
            raise ValueError(
                f"The number of sampling rates ({len(audio_sampling_rate)}) must match the number of audio ({len(audio)})"
            )

        audio_buffers = []
        for array, source_rate in zip(audio, audio_sampling_rate):
            array = np.asarray(array)
            # Convert to mono if needed
            # NOTE(review): averaging over axis 1 assumes a (samples, channels) layout - confirm with callers.
            if array.ndim == 2:
                array = array.mean(axis=1)
            # Resample if the sampling rate is different from the expected one
            if source_rate != sampling_rate:
                array = soxr.resample(array, source_rate, sampling_rate, quality="HQ")
            audio_buffers.append(array)

        return audio_buffers

    # noinspection PyTypedDict
    def _merge_kwargs(self, **kwargs) -> CM3PProcessorKwargs:
        """Merge call-time keyword arguments into a copy of the processor defaults.

        Keyword arguments may be passed flat (`max_length=...`) or nested per modality
        (`beatmap_kwargs={...}`). Flat arguments are applied to every modality whose annotations list the key.
        Finally the common kwargs are copied into every modality dictionary.

        Note that values are popped from nested dictionaries passed by the caller.

        Returns:
            The merged kwargs with the keys `beatmap_kwargs`, `metadata_kwargs`, `audio_kwargs` and
            `common_kwargs` (as far as they are present in the defaults).

        Raises:
            ValueError: If a key is passed both inside a modality dictionary and as a flat keyword argument.
        """
        empty = "__empty__"
        nested_modalities = ["beatmap_kwargs", "metadata_kwargs", "audio_kwargs", "common_kwargs"]
        possible_modality_keywords = {"beatmap", "metadata", "audio"}
        used_keys = set()

        # pass defaults to output dictionary
        output_kwargs = CM3PProcessorKwargs()
        output_kwargs.update(copy.deepcopy(self.default_kwargs))

        # update modality kwargs with passed kwargs
        non_modality_kwargs = set(kwargs) - set(output_kwargs)
        for modality, output_kwarg in output_kwargs.items():
            for modality_key in CM3PProcessorKwargs.__annotations__[modality].__annotations__:
                # check if we received a structured kwarg dict or not to handle it correctly
                if modality in kwargs:
                    kwarg_value = kwargs[modality].pop(modality_key, empty)
                    # check if this key was passed as a flat kwarg.
                    if kwarg_value != empty and modality_key in non_modality_kwargs:
                        raise ValueError(
                            f"Keyword argument {modality_key} was passed two times:\n"
                            f"in a dictionary for {modality} and as a **kwarg."
                        )
                elif modality_key in kwargs:
                    # we get a modality_key instead of popping it because modality-specific processors
                    # can have overlapping kwargs
                    kwarg_value = kwargs.get(modality_key, empty)
                else:
                    kwarg_value = empty

                if not isinstance(kwarg_value, str) or kwarg_value != empty:
                    output_kwarg[modality_key] = kwarg_value
                    used_keys.add(modality_key)

        # Determine if kwargs is a flat dictionary or contains nested dictionaries
        if any(key in nested_modalities for key in kwargs):
            # kwargs is dictionary-based, and some keys match modality names
            for modality, subdict in kwargs.items():
                if modality not in nested_modalities:
                    continue
                for subkey, subvalue in subdict.items():
                    if subkey not in used_keys:
                        output_kwargs[modality][subkey] = subvalue
                        used_keys.add(subkey)
        else:
            # kwargs is a flat dictionary
            common_keys = CM3PProcessorKwargs.__annotations__["common_kwargs"].__annotations__
            for key, kwarg in kwargs.items():
                if key in used_keys:
                    continue
                if key in common_keys:
                    output_kwargs["common_kwargs"][key] = kwarg
                elif key not in possible_modality_keywords:
                    # noinspection PyUnresolvedReferences
                    logger.warning_once(
                        f"Keyword argument `{key}` is not a valid argument for this processor and will be ignored."
                    )

        # all modality-specific kwargs are updated with common kwargs
        for kwarg in output_kwargs.values():
            kwarg.update(output_kwargs["common_kwargs"])
        return output_kwargs

    def __call__(
            self,
            metadata: Optional[Union[CM3PMetadata, list[CM3PMetadata]]] = None,
            beatmap: Optional[Union[str, list[str], PathLike, list[PathLike], IO[str], list[IO[str]], Beatmap, list[Beatmap]]] = None,
            audio: Optional[Union[str, list[str], Path, list[Path], AudioInput]] = None,
            audio_sampling_rate: Optional[Union[int, list[int]]] = None,
            speed: float = 1.0,
            multiply_metadata: bool = False,
            populate_metadata: bool = False,
            metadata_dropout_prob: float = 0.0,
            metadata_variations: int = 1,
            **kwargs,
    ):
        """Encode beatmaps (optionally with audio) and/or metadata into model inputs.

        Each beatmap is cut into sliding windows of `window_length_sec` seconds every `window_stride_sec` seconds;
        every window becomes one row of the beatmap encoding. Windows starting in the last 8 seconds of the song
        are not generated.

        Args:
            metadata: One metadata dict or a list of them. Not modified by this method.
            beatmap: One beatmap (anything `load_beatmap` accepts) or a list of them.
            audio: Audio matching the beatmaps one-to-one; paths or decoded audio.
            audio_sampling_rate: Sampling rate(s) of decoded audio, see `_load_audio`.
            speed: Playback speed forwarded to the beatmap parser, to `get_metadata` and to audio loading.
            multiply_metadata: Emit one metadata entry per window (with `song_position` set when
                `populate_metadata` is also enabled) instead of one per beatmap.
            populate_metadata: Merge metadata derived from the beatmap/audio into the given metadata.
            metadata_dropout_prob: Probability with which each non-`None` metadata field is replaced by `None`.
            metadata_variations: Number of metadata entries per input (the original plus
                `metadata_variations - 1` variations produced by the metadata tokenizer).
            **kwargs: Flat or nested processor kwargs, see `_merge_kwargs`.

        Returns:
            The beatmap encoding (with `input_features` when audio was given, and `metadata_ids`,
            `metadata_attention_mask` and possibly `metadata_variation_classes` when metadata was encoded),
            or only the metadata encoding when no beatmap was given.

        Raises:
            ValueError: For unsupported `return_tensors`, when neither metadata nor beatmap is given, or when the
                numbers of beatmaps, audio and metadata entries do not match.
        """
        output_kwargs = self._merge_kwargs(**kwargs)

        beatmap_kwargs: CM3PTokenizerKwargs = output_kwargs["beatmap_kwargs"]
        metadata_kwargs: CM3PTokenizerKwargs = output_kwargs["metadata_kwargs"]
        audio_kwargs: CM3PAudioKwargs = output_kwargs["audio_kwargs"]
        common_kwargs: CommonKwargs = output_kwargs["common_kwargs"]

        # The window settings are consumed here and must not reach the beatmap tokenizer.
        window_length_sec = beatmap_kwargs.pop("window_length_sec")
        window_stride_sec = beatmap_kwargs.pop("window_stride_sec")
        max_length = beatmap_kwargs.get("max_length", 8000)
        metadata_max_length = metadata_kwargs.get("max_length", 128)
        sampling_rate = audio_kwargs["sampling_rate"]
        max_source_positions = audio_kwargs.get("max_source_positions", 3000)
        # Padding is done by `_pad_audio`, not by the feature extractor.
        audio_kwargs["padding"] = False
        return_tensors = common_kwargs["return_tensors"]
        use_torch = return_tensors == "pt"

        metadata_encoding, beatmap_encoding, metadata_variation_classes = None, None, None

        if return_tensors is not None and not use_torch:
            raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'` or `return_tensors=None`.")

        if metadata is None and beatmap is None:
            raise ValueError("You have to specify either metadata or beatmap. Both cannot be none.")

        # Remember whether audio was really provided: `audio` is replaced by a list of `None`
        # placeholders below, so `audio is not None` cannot be used for that later on.
        has_audio = audio is not None
        if has_audio:
            audio = self._load_audio(
                sampling_rate,
                audio,
                audio_sampling_rate=audio_sampling_rate,
                # Keep path-loaded audio in sync with the speed-adjusted beatmap.
                speed=speed,
            )

        if beatmap is not None:
            if not isinstance(beatmap, list):
                beatmap = [beatmap]

            if has_audio:
                if len(beatmap) != len(audio):
                    raise ValueError(
                        f"The number of beatmaps ({len(beatmap)}) must match the number of audio ({len(audio)})"
                    )
            else:
                audio = [None] * len(beatmap)

            if (multiply_metadata or populate_metadata) and metadata is not None:
                matched_metadata = metadata if isinstance(metadata, list) else [metadata]
                if len(matched_metadata) != len(beatmap):
                    raise ValueError(
                        f"The number of metadata entries ({len(matched_metadata)}) must match the number of beatmaps ({len(beatmap)})"
                        "` if multiply_metadata` or `populate_metadata` is set to True."
                    )
            elif populate_metadata:
                matched_metadata = [CM3PMetadata()] * len(beatmap)
            else:
                matched_metadata = [None] * len(beatmap)

            new_metadata = []
            batch_start_ms = []
            batch_groups = []
            batch_audio = []
            batch_num_audio_tokens = []
            # Windows starting less than this many seconds before the end of the song are skipped.
            min_window_length_sec = 8

            for b, m, audio_array in zip(beatmap, matched_metadata, audio):
                b: Beatmap = load_beatmap(b)
                song_length = get_song_length(audio_array, sampling_rate, b)
                beatmap_groups = self.beatmap_parser.parse_beatmap(b, speed=speed, song_length=song_length)

                def add_metadata(song_position: Optional[float] = None):
                    if populate_metadata:
                        new_metadata.append(merge_metadata_dicts(m, get_metadata(
                            beatmap=b,
                            audio_samples=audio_array,
                            sampling_rate=sampling_rate,
                            speed=speed,
                            song_position=song_position,
                        )))
                    else:
                        new_metadata.append(m)

                if not multiply_metadata:
                    add_metadata()

                # Loop through with sliding window
                groups_search_index = 0
                for start_sec in np.arange(0, song_length - min_window_length_sec, window_stride_sec):
                    end_sec = start_sec + window_length_sec

                    if audio_array is not None:
                        # Slice audio waveform
                        start_frame = int(start_sec * sampling_rate)
                        end_frame = int(end_sec * sampling_rate)
                        audio_slice = audio_array[start_frame:end_frame]
                        # Pad the audio array and calculate the number of audio tokens
                        audio_slice, num_audio_tokens = self._encode_audio(audio_slice, **audio_kwargs)
                    else:
                        audio_slice = None
                        num_audio_tokens = 0

                    # Find groups that fall within the current window
                    # Groups are sorted by time, so we can use a simple linear search from the last index
                    start_ms = start_sec * 1000
                    end_ms = end_sec * 1000
                    next_start_ms = (start_sec + window_stride_sec) * 1000
                    window_groups = []
                    for group in itertools.islice(beatmap_groups, groups_search_index, None):
                        # Groups before the next window start are never needed again.
                        if group.time < next_start_ms:
                            groups_search_index += 1

                        if group.time < start_ms:
                            continue
                        elif group.time < end_ms:
                            window_groups.append(group)
                        else:
                            break

                    batch_start_ms.append(start_ms)
                    batch_groups.append(window_groups)
                    batch_audio.append(audio_slice)
                    batch_num_audio_tokens.append(num_audio_tokens)

                    if multiply_metadata:
                        add_metadata(start_sec / song_length)

            if populate_metadata or multiply_metadata:
                metadata = new_metadata

            if len(batch_groups) > 0:
                beatmap_encoding = self.beatmap_tokenizer(
                    groups=batch_groups,
                    window_start_ms=batch_start_ms,
                    num_audio_tokens=batch_num_audio_tokens,
                    **beatmap_kwargs,
                )

                if has_audio:
                    data = dict(beatmap_encoding)
                    data["input_features"] = self._retrieve_input_features(batch_audio, **audio_kwargs)
                    beatmap_encoding = BatchFeature(data, tensor_type=return_tensors)
            else:
                # No windows with hit objects were found, return empty encoding
                logger.warning("Warning: No windows with hit objects were found in the provided beatmap(s). Returning empty encoding.")
                beatmap_encoding = BatchEncoding(
                    {
                        "input_ids": torch.zeros((0, max_length), dtype=torch.long) if use_torch else [],
                        "attention_mask": torch.zeros((0, max_length), dtype=torch.long) if use_torch else [],
                    },
                    tensor_type=return_tensors,
                )
                if has_audio:
                    data = dict(beatmap_encoding)
                    data["input_features"] = torch.zeros(
                        (0, self.audio_feature_extractor.feature_size, max_source_positions), dtype=torch.float
                    ) if use_torch else []
                    beatmap_encoding = BatchFeature(data, tensor_type=return_tensors)

        if metadata is not None and not (isinstance(metadata, list) and any(m is None for m in metadata)):
            if not isinstance(metadata, list):
                metadata = [metadata]

            if metadata_dropout_prob > 0.0:
                # Drop fields on shallow copies: the caller's dicts stay untouched, and entries that alias the
                # same dict (one per window with `multiply_metadata`) each get an independent dropout.
                dropped_metadata = []
                for m in metadata:
                    m = copy.copy(m)
                    # Randomly drop out metadata fields
                    for key, value in m.items():
                        if value is not None and np.random.rand() < metadata_dropout_prob:
                            # noinspection PyTypedDict
                            m[key] = None
                    dropped_metadata.append(m)
                metadata = dropped_metadata

            if metadata_variations > 1:
                extended_metadata = []
                metadata_variation_classes = []
                for m in metadata:
                    m_vars, m_classes = zip(*self.metadata_tokenizer.metadata_variations(m, metadata_variations - 1))
                    extended_metadata.append(m)
                    extended_metadata.extend(m_vars)
                    metadata_variation_classes.append([0] + list(m_classes))  # Class 0 is the original metadata

                if len(extended_metadata) != len(metadata) * metadata_variations:
                    raise ValueError(
                        f"Expected {metadata_variations - 1} metadata variations per entry, "
                        f"got {len(extended_metadata)} entries for {len(metadata)} inputs."
                    )
                metadata = extended_metadata

            if len(metadata) > 0:
                metadata_encoding = self.metadata_tokenizer(
                    metadata,
                    **metadata_kwargs,
                )
                if metadata_variations > 1:
                    # Reshape to (batch_size, variations, seq_len)
                    for k, v in list(metadata_encoding.items()):
                        if use_torch:
                            v = v.view(len(metadata) // metadata_variations, metadata_variations, -1)
                        else:
                            v = [v[i:i + metadata_variations] for i in range(0, len(v), metadata_variations)]
                        metadata_encoding[k] = v
                    if metadata_variation_classes is not None:
                        metadata_encoding["metadata_variation_classes"] = (
                            torch.tensor(metadata_variation_classes, dtype=torch.long)
                            if use_torch else metadata_variation_classes
                        )
            else:
                metadata_encoding = BatchEncoding(
                    {
                        "input_ids": torch.zeros((0, metadata_max_length), dtype=torch.long) if use_torch else [],
                        "attention_mask": torch.zeros((0, metadata_max_length), dtype=torch.long) if use_torch else [],
                    },
                    tensor_type=return_tensors,
                )

        if beatmap_encoding is None:
            return metadata_encoding

        if metadata_encoding is not None:
            beatmap_encoding["metadata_ids"] = metadata_encoding["input_ids"]
            beatmap_encoding["metadata_attention_mask"] = metadata_encoding["attention_mask"]
            if "metadata_variation_classes" in metadata_encoding:
                beatmap_encoding["metadata_variation_classes"] = metadata_encoding["metadata_variation_classes"]
        return beatmap_encoding

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to CM3PBeatmapTokenizer's [`~CM3PBeatmapTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.beatmap_tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to CM3PBeatmapTokenizer's [`~CM3PBeatmapTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.beatmap_tokenizer.decode(*args, **kwargs)

    def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
        """Save every sub-component into its own subfolder plus the processor config file.

        Args:
            save_directory: Target directory (string or path-like); created if missing.
            push_to_hub: Whether to upload the files written by this call afterwards.
            **kwargs: `commit_message`, `repo_id` and `token` are used for the upload; the remaining
                entries are forwarded to `_create_repo`.

        Returns:
            A list containing the path of the written processor config file.
        """
        # Accept `Path` objects as well: the default repo id is derived with `str.split`.
        save_directory = os.fspath(save_directory)
        os.makedirs(save_directory, exist_ok=True)

        commit_message, repo_id, files_timestamps = None, None, None
        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
            repo_id = self._create_repo(repo_id, **kwargs)
            # The snapshot has to be taken BEFORE writing: `_upload_modified_files` only uploads
            # files that are new or newer than this snapshot.
            files_timestamps = self._get_files_timestamps(save_directory)

        for attribute_name in self.attributes:
            attribute = getattr(self, attribute_name)
            # Include the processor class in the attribute config so this processor can then be reloaded with the
            # `AutoProcessor` API.
            if hasattr(attribute, "_set_processor_class"):
                # noinspection PyProtectedMember
                attribute._set_processor_class(self.__class__.__name__)
            attribute.save_pretrained(os.path.join(save_directory, attribute_name))

        output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
        self.to_json_file(output_processor_file)
        # noinspection PyUnresolvedReferences
        logger.warning_once(f"processor saved in {output_processor_file}")

        if push_to_hub:
            self._upload_modified_files(
                save_directory,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=kwargs.get("token"),
            )

        return [output_processor_file]

    @classmethod
    def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load every sub-component from its own subfolder of a pretrained checkpoint.

        Each attribute listed in `attributes` is loaded with the class named by `<attribute>_class` from the
        subfolder `<subfolder>/<attribute>` (or `<attribute>` when no `subfolder` kwarg is given).

        Returns:
            The loaded components, in the order of `attributes`.
        """
        subfolder = kwargs.pop("subfolder", None)
        args = []
        for attribute_name in cls.attributes:
            class_name = getattr(cls, f"{attribute_name}_class")
            attribute_class = cls.get_possibly_dynamic_module(class_name)
            attribute_subfolder = os.path.join(subfolder, attribute_name) if subfolder else attribute_name

            args.append(attribute_class.from_pretrained(
                pretrained_model_name_or_path,
                subfolder=attribute_subfolder,
                **kwargs
            ))

        return args
701
+
702
# Register the processor so that `AutoProcessor` resolves `CM3PConfig` to `CM3PProcessor`.
AutoProcessor.register(CM3PConfig, CM3PProcessor)

# Names exported by `from processing_cm3p import *`.
__all__ = ["CM3PProcessor", "get_metadata"]