File size: 37,922 Bytes
4015376
 
 
 
 
 
 
 
 
63af269
4015376
 
 
63af269
 
4015376
63af269
4015376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63af269
4015376
 
 
 
 
 
 
63af269
4015376
 
 
 
 
 
 
 
 
 
 
c488d69
4015376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63af269
4015376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c488d69
4015376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c488d69
4015376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63af269
4015376
 
 
 
 
 
 
 
 
 
 
 
 
63af269
4015376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63af269
 
 
 
 
 
 
 
4015376
 
63af269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4015376
 
63af269
 
 
4015376
 
 
63af269
4015376
 
63af269
 
 
 
 
 
 
 
4015376
63af269
4015376
63af269
 
 
 
 
 
4015376
63af269
 
4015376
 
 
 
 
 
63af269
 
 
4015376
 
63af269
 
4015376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63af269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4015376
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
import copy
import itertools
import math
import os
from os import PathLike
from pathlib import Path
from typing import Optional, Union, IO, TypedDict

import numpy as np
from huggingface_hub.errors import HfHubHTTPError
from pandas import Series
from slider import Beatmap, HoldNote
from transformers import WhisperFeatureExtractor, AutoProcessor, BatchEncoding
from transformers.dynamic_module_utils import custom_object_save
from transformers.tokenization_utils_base import TruncationStrategy, PreTrainedTokenizerBase
from transformers.utils import is_torch_available, PaddingStrategy, PROCESSOR_NAME, logging
from huggingface_hub import CommitOperationAdd, create_branch, create_commit

from .configuration_cm3p import CM3PConfig
from .parsing_cm3p import CM3PBeatmapParser, load_beatmap, get_song_length
from .tokenization_cm3p import CM3PBeatmapTokenizer, CM3PMetadataTokenizer, CM3PMetadata, merge_metadata_dicts

if is_torch_available():
    import torch

from transformers.audio_utils import AudioInput, make_list_of_audio, load_audio
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import AudioKwargs, ProcessorMixin, CommonKwargs

logger = logging.get_logger(__name__)


def get_hold_note_ratio(beatmap: Beatmap) -> Optional[float]:
    """
    Return the fraction of hit objects in `beatmap` that are `HoldNote` instances.

    Hit objects are read with `stacking=False`. Returns `None` when the beatmap has no hit objects.
    """
    hit_objects = beatmap.hit_objects(stacking=False)
    total = len(hit_objects)
    if not total:
        return None
    holds = sum(1 for hit_object in hit_objects if isinstance(hit_object, HoldNote))
    return holds / total


def get_scroll_speed_ratio(beatmap: Beatmap) -> Optional[float]:
    """
    Return the number of scroll-speed changes per distinct hit-object time.

    Each run of consecutive hit objects sharing the same `time` counts as one hit-object time. Timing points
    are walked in order:

    - An uninherited point (`parent is None`) resets the tracked speed to 1 without counting as a change.
    - An inherited point has speed `-100 / ms_per_beat`. It counts as a change when it differs from the
      tracked speed. Inherited points seen while no speed is tracked yet (tracked value -1) are not counted.

    Returns `None` when the beatmap has no hit objects.
    """
    hit_objects = beatmap.hit_objects(stacking=False)
    if not len(hit_objects):
        return None

    time_count = 0
    previous_time = -1
    for hit_object in hit_objects:
        if hit_object.time == previous_time:
            continue
        time_count += 1
        previous_time = hit_object.time

    change_count = 0
    tracked_speed = -1  # -1 means "no speed tracked yet"
    for point in beatmap.timing_points:
        if point.parent is None:
            tracked_speed = 1
            continue
        current_speed = -100 / point.ms_per_beat
        if tracked_speed != -1 and current_speed != tracked_speed:
            change_count += 1
        tracked_speed = current_speed

    return change_count / time_count


def get_hitsounded_status(beatmap: Beatmap) -> bool:
    """Return True if any hit object of `beatmap` (read with `stacking=False`) has a non-zero `hitsound`."""
    return any(hit_object.hitsound != 0 for hit_object in beatmap.hit_objects(stacking=False))


def get_difficulty(beatmap_metadata: Series, speed: float = 1.0) -> float:
    """
    Estimate the star rating of a beatmap at an arbitrary playback speed.

    `beatmap_metadata["StarRating"]` holds star ratings sampled at speeds 0.5, 0.75, 1.0, 1.25, 1.5, 1.75
    and 2.0. The result is linearly interpolated between the two nearest samples with `np.interp`. That
    function clamps speeds outside [0.5, 2.0] to the nearest endpoint's rating.
    """
    sampled_speeds = (0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0)
    return np.interp(speed, sampled_speeds, beatmap_metadata["StarRating"])


def get_metadata(
        beatmap_metadata: Optional[Series] = None,
        beatmap: Optional[Beatmap] = None,
        audio_samples: Optional[np.ndarray] = None,
        sampling_rate: Optional[int] = None,
        speed: float = 1.0,
        song_position: Optional[float] = None,
) -> CM3PMetadata:
    """
    Build a `CM3PMetadata` record from a parsed beatmap and/or a row of beatmap metadata.

    Mode and circle size come from `beatmap` when it is given, otherwise from `beatmap_metadata`.

    - Fields computed from the beatmap contents (hitsounds, SV, key count, ratios) require `beatmap`.
    - Fields read from the metadata row require `beatmap_metadata`.
    - Mode-specific fields are `None` for modes they do not apply to.
    - Any field whose source is missing is `None`.

    Args:
        beatmap_metadata: Metadata row. Reads "ModeInt", "Cs", "StarRating", "SubmittedDate", "Status",
            "UserId" and "TopTagIds".
        beatmap: Parsed beatmap. Takes precedence over `beatmap_metadata` for mode and circle size.
        audio_samples: Audio waveform, forwarded to `get_song_length`.
        sampling_rate: Sampling rate of `audio_samples`, forwarded to `get_song_length`.
        speed: Playback speed used to interpolate the star rating (see `get_difficulty`).
        song_position: Stored unchanged in the `song_position` field.

    Returns:
        The assembled `CM3PMetadata`.
    """
    has_beatmap = beatmap is not None
    has_metadata = beatmap_metadata is not None

    if has_beatmap:
        mode = beatmap.mode
        circle_size = beatmap.circle_size
    elif has_metadata:
        mode = beatmap_metadata["ModeInt"]
        circle_size = beatmap_metadata["Cs"]
    else:
        mode = None
        circle_size = None

    song_length = get_song_length(audio_samples, sampling_rate, beatmap)

    # Circle size and the global slider multiplier are only reported for modes 0 and 2.
    uses_circle_size = mode in (0, 2)
    # Mode 3 is mania: circle size doubles as the key count there.
    is_mania = mode == 3

    return CM3PMetadata(
        difficulty=get_difficulty(beatmap_metadata, speed) if has_metadata else None,
        year=beatmap_metadata["SubmittedDate"].year if has_metadata else None,
        mode=mode,
        status=beatmap_metadata["Status"] if has_metadata else None,
        mapper=beatmap_metadata["UserId"] if has_metadata else None,
        cs=circle_size if uses_circle_size else None,
        hitsounded=get_hitsounded_status(beatmap) if has_beatmap else None,
        song_length=song_length,
        song_position=song_position,
        global_sv=beatmap.slider_multiplier if uses_circle_size and has_beatmap else None,
        mania_keycount=int(circle_size) if is_mania and has_beatmap else None,
        hold_note_ratio=get_hold_note_ratio(beatmap) if is_mania and has_beatmap else None,
        scroll_speed_ratio=get_scroll_speed_ratio(beatmap) if mode in (1, 3) and has_beatmap else None,
        tags=beatmap_metadata["TopTagIds"].tolist() if has_metadata else None,
    )


class CM3PTokenizerKwargs(TypedDict, total=False):
    """
    Keyword arguments for the CM3P beatmap and metadata tokenizers.

    `CM3PProcessor.__call__` forwards the merged `beatmap_kwargs` and `metadata_kwargs` groups to the
    corresponding tokenizer.

    NOTE(review): the keys appear to mirror the arguments of `PreTrainedTokenizerBase.__call__`. Confirm
    that both CM3P tokenizers accept every one of them.
    """
    add_special_tokens: Optional[bool]
    padding: Union[bool, str, PaddingStrategy]
    truncation: Union[bool, str, TruncationStrategy]
    max_length: Optional[int]
    pad_to_multiple_of: Optional[int]
    return_token_type_ids: Optional[bool]
    return_attention_mask: Optional[bool]
    return_overflowing_tokens: Optional[bool]
    return_special_tokens_mask: Optional[bool]
    return_offsets_mapping: Optional[bool]
    return_length: Optional[bool]
    verbose: Optional[bool]
    padding_side: Optional[str]
    return_mm_token_type_ids: Optional[bool]


class CM3PBeatmapKwargs(CM3PTokenizerKwargs, total=False):
    """
    Beatmap tokenizer kwargs plus the sliding-window settings used by `CM3PProcessor.__call__`.

    The three window settings are popped from the beatmap kwargs before the rest reach the beatmap tokenizer.
    """
    # Length of each window in seconds (window end = window start + this value).
    window_length_sec: float
    # Step between consecutive window starts, in seconds.
    window_stride_sec: float
    # Window starts stay strictly below `song_length - min_window_length_sec`, so no window starts
    # within this many seconds of the end of the song.
    min_window_length_sec: float


class CM3PAudioKwargs(AudioKwargs, total=False):
    """
    Audio kwargs for `CM3PProcessor`.

    Apart from `max_source_positions`, these are also forwarded to the audio feature extractor.
    """
    # Spectrogram frames per input-feature chunk (last axis of the reshaped features).
    max_source_positions: Optional[int]
    # Audio samples per spectrogram frame; used to estimate the frame count of an audio slice.
    hop_length: Optional[int]
    # Minimum padded audio length in samples, applied when no `pad_to_multiple_of` is set.
    window_size: Optional[int]
    # Spectrogram frames represented by a single audio token.
    audio_length_per_tok: Optional[int]
    # Presumably the torch device used for feature extraction — TODO confirm with the feature extractor.
    device: Optional[str]


# noinspection PyTypedDict
class CM3PProcessorKwargs(CommonKwargs, CM3PBeatmapKwargs, CM3PTokenizerKwargs, CM3PAudioKwargs, total=False):
    """
    Keyword arguments accepted by `CM3PProcessor.__call__`, grouped per modality.

    `_defaults` holds the processor-wide defaults for each group. `CM3PProcessor` uses a deep copy of it when
    no `default_kwargs` are supplied. The type annotation of each `*_kwargs` entry names the TypedDict whose
    keys `CM3PProcessor._merge_kwargs` routes into that group.
    """

    _defaults = {
        "beatmap_kwargs": {
            "max_length": 8000,
            "padding": PaddingStrategy.LONGEST,
            "truncation": TruncationStrategy.LONGEST_FIRST,
            "window_length_sec": 30.0,
            "window_stride_sec": 30.0,
            "min_window_length_sec": 1.0,
        },
        "metadata_kwargs": {
            "max_length": 128,
            "padding": PaddingStrategy.LONGEST,
            "truncation": TruncationStrategy.LONGEST_FIRST,
        },
        "audio_kwargs": {
            "sampling_rate": 16000,
            "padding": True,
            "truncation": False,
            "pad_to_multiple_of": 480000,
            "max_source_positions": 3000,
            "hop_length": 160,
            "window_size": 400,
            "audio_length_per_tok": 8,
            "device": "cpu",
        },
        "common_kwargs": {
            "return_tensors": "pt",
        },
    }

    # Per-modality key templates. Each one is built from the annotations of the TypedDict named in its type
    # hint, so the template lists exactly the keys that modality accepts.
    common_kwargs: CommonKwargs = {
        **CommonKwargs.__annotations__,
    }
    beatmap_kwargs: CM3PBeatmapKwargs = {
        **CM3PBeatmapKwargs.__annotations__,
    }
    metadata_kwargs: CM3PTokenizerKwargs = {
        **CM3PTokenizerKwargs.__annotations__,
    }
    audio_kwargs: CM3PAudioKwargs = {
        **CM3PAudioKwargs.__annotations__,
    }


class CM3PProcessor(ProcessorMixin):
    r"""
    Constructs a CM3P processor.

    It wraps a [`WhisperFeatureExtractor`], a [`CM3PBeatmapParser`], a [`CM3PBeatmapTokenizer`] and a
    [`CM3PMetadataTokenizer`] into a single processor. The processor combines audio feature extraction,
    beatmap parsing and tokenization, and metadata tokenization.

    Args:
        audio_feature_extractor ([`WhisperFeatureExtractor`]):
            The feature extractor is a required input.
        beatmap_parser ([`CM3PBeatmapParser`]):
            The beatmap parser is a required input.
        beatmap_tokenizer ([`CM3PBeatmapTokenizer`]):
            The beatmap tokenizer is a required input. Its `audio_token` is exposed as `self.audio_token`.
        metadata_tokenizer ([`CM3PMetadataTokenizer`]):
            The metadata tokenizer is a required input.
        default_kwargs (`CM3PProcessorKwargs`, *optional*):
            Default keyword arguments for the processor. If not provided (or empty), a deep copy of
            `CM3PProcessorKwargs._defaults` is used.
    """

    # Sub-component names and their class names; presumably consumed by `ProcessorMixin` when saving and
    # loading the processor — verify against the installed transformers version.
    attributes = ["audio_feature_extractor", "beatmap_parser", "beatmap_tokenizer", "metadata_tokenizer"]
    audio_feature_extractor_class = "WhisperFeatureExtractor"
    beatmap_parser_class = "CM3PBeatmapParser"
    beatmap_tokenizer_class = "CM3PBeatmapTokenizer"
    metadata_tokenizer_class = "CM3PMetadataTokenizer"

    def __init__(
        self,
        audio_feature_extractor: WhisperFeatureExtractor,
        beatmap_parser: CM3PBeatmapParser,
        beatmap_tokenizer: CM3PBeatmapTokenizer,
        metadata_tokenizer: CM3PMetadataTokenizer,
        default_kwargs: Optional[CM3PProcessorKwargs] = None,
    ):
        """Store the four sub-components and resolve the processor-wide default kwargs."""
        self.audio_feature_extractor = audio_feature_extractor
        self.beatmap_parser = beatmap_parser
        self.beatmap_tokenizer = beatmap_tokenizer
        self.metadata_tokenizer = metadata_tokenizer
        self.audio_token = beatmap_tokenizer.audio_token

        if not default_kwargs:
            # Missing or empty defaults fall back to a private copy of the class-level defaults,
            # so per-instance edits never leak into `CM3PProcessorKwargs._defaults`.
            # noinspection PyProtectedMember
            default_kwargs = copy.deepcopy(CM3PProcessorKwargs._defaults)
        self.default_kwargs = default_kwargs

        super().__init__(audio_feature_extractor, beatmap_parser, beatmap_tokenizer, metadata_tokenizer)

    def _pad_audio(
        self,
        audio_array: np.ndarray,
        window_size: int = 400,
        pad_to_multiple_of: Optional[int] = 480000,
        **_,
    ) -> np.ndarray:
        r"""Right-pad an audio array with zeros to the length expected by the feature extractor.

        Args:
            audio_array: Audio samples. Expected to be 1-D: `np.pad` with `(0, n)` pads every axis.
            window_size: Minimum output length in samples, used when `pad_to_multiple_of` is falsy. It keeps
                the audio at least one spectrogram frame long.
            pad_to_multiple_of: If set, the length is padded up to the next multiple of this value. An empty
                input is padded to one full multiple, so it still yields one complete feature chunk.
            **_: Any other audio kwargs are accepted and ignored.

        Returns:
            Padded audio array.
        """
        num_samples = audio_array.shape[-1]
        if pad_to_multiple_of:
            # max(1, ...) so an empty slice still produces one chunk instead of zero samples,
            # which the feature extractor cannot process.
            num_chunks = max(1, math.ceil(num_samples / pad_to_multiple_of))
            audio_array = np.pad(audio_array, (0, num_chunks * pad_to_multiple_of - num_samples))
        elif num_samples < window_size:
            # minimum length for audios is at least one spectrogram frame
            audio_array = np.pad(audio_array, (0, window_size - num_samples))

        return audio_array

    def _encode_audio(
        self,
        audio: np.ndarray,
        hop_length: int = 160,
        audio_length_per_tok: int = 8,
        **kwargs,
    ) -> tuple[np.ndarray, int]:
        """
        Pad `audio` and estimate how many audio tokens it will occupy.

        The remaining `kwargs` are forwarded to `_pad_audio`. The spectrogram frame count is estimated from
        the padded length and `hop_length`, then divided (rounding up) by `audio_length_per_tok`.

        Returns:
            `(padded_audio, num_audio_tokens)`.
        """
        padded = self._pad_audio(audio, **kwargs)
        num_samples = padded.shape[0]

        # The log-mel spectrogram downsamples the waveform by hop_length.
        if num_samples % hop_length == 0:
            num_frames = num_samples // hop_length
        else:
            num_frames = math.ceil(num_samples / hop_length - 1)

        return padded, math.ceil(num_frames / audio_length_per_tok)

    def _retrieve_input_features(self, audio, max_source_positions, **kwargs) -> "Union[torch.Tensor, np.ndarray]":
        """
        Compute log-mel input features for each audio array and stack them as fixed-size chunks.

        Each array is passed to the audio feature extractor with the remaining `kwargs`. The arrays are
        expected to be padded to a multiple of 480000 samples (a multiple of 30 s), see the default
        `audio_kwargs` of `CM3PProcessorKwargs`.

        The features of each array are reshaped to `(feature_size, num_chunks, max_source_positions)`. The
        first two axes are then swapped and all chunks are concatenated along the batch dimension.

        The return annotation is a string so that defining the class does not require torch. The module only
        imports torch when it is available.

        Returns:
            Features of shape `(total_chunks, feature_size, max_source_positions)`: a torch tensor when
            `return_tensors == "pt"` (the default), otherwise a numpy array. An empty `audio` yields an empty
            result with zero chunks.
        """
        return_tensors = kwargs.get("return_tensors", "pt")
        feature_size = self.audio_feature_extractor.feature_size

        input_features_list = []
        for audio_array in audio:
            audio_inputs = self.audio_feature_extractor(audio_array, **kwargs)

            # let's split into chunks of max_source_positions, and then stack them along batch dimension
            input_features = audio_inputs["input_features"].reshape(feature_size, -1, max_source_positions)
            input_features_list.append(input_features.swapaxes(0, 1))

        if not input_features_list:
            # torch.cat / np.concatenate reject an empty list; return an explicit zero-chunk result instead.
            empty_shape = (0, feature_size, max_source_positions)
            if return_tensors == "pt":
                return torch.zeros(empty_shape, dtype=torch.float)
            return np.zeros(empty_shape, dtype=np.float32)

        if return_tensors == "pt":
            return torch.cat(input_features_list)

        return np.concatenate(input_features_list)

    def _load_audio(
        self,
        sampling_rate: int,
        audio: Union[str, list[str], Path, list[Path], AudioInput],
        audio_sampling_rate: Optional[Union[int, list[int]]] = None,
        speed: float = 1.0,
    ) -> list[np.ndarray]:
        """
        Load audio from paths or in-memory arrays and return one buffer per input at `sampling_rate`.

        Args:
            sampling_rate: Target sampling rate of every returned buffer.
            audio: A path (`str` or `Path`), a list of paths, or in-memory audio accepted by
                `make_list_of_audio`.
            audio_sampling_rate: Sampling rate of in-memory audio: one int (Python or numpy integer) for all
                inputs, or a list with one rate per input. It is ignored for path inputs. If omitted for
                in-memory audio, `sampling_rate` is assumed and a warning is logged.
            speed: Only applied to path inputs. Files are decoded at `int(sampling_rate // speed)` and the
                samples are then treated as being at `sampling_rate`, which effectively speeds the audio up
                by `speed`. NOTE(review): `__call__` currently does not pass its `speed` argument here.

        Returns:
            A list of numpy arrays, one per input. 2-D arrays are averaged over axis 1 to get mono audio.

        Raises:
            ValueError: If a list of sampling rates does not have exactly one entry per audio input.
        """

        # convert Path objects to str
        if isinstance(audio, Path):
            audio = str(audio)
        if isinstance(audio, list) and all(isinstance(el, Path) for el in audio):
            audio = [str(el) for el in audio]

        # validate audio input
        is_str = isinstance(audio, str)
        is_list_of_str = isinstance(audio, list) and all(isinstance(el, str) for el in audio)
        is_list_of_audio = not (is_str or is_list_of_str)

        if is_list_of_audio and audio_sampling_rate is None:
            # noinspection PyUnresolvedReferences
            logger.warning_once(
                f"You've provided audio without specifying the sampling rate. It will be assumed to be {sampling_rate}, which can result in silent errors."
            )
            audio_sampling_rate = sampling_rate

        if is_str:
            audio = [load_audio(audio, sampling_rate=int(sampling_rate // speed))]
            audio_sampling_rate = sampling_rate
        elif is_list_of_str:
            audio = [load_audio(el, sampling_rate=int(sampling_rate // speed)) for el in audio]
            audio_sampling_rate = sampling_rate

        audio = make_list_of_audio(audio)

        # Broadcast a single rate; numpy integers (e.g. read from a dataframe) are accepted too.
        if isinstance(audio_sampling_rate, (int, np.integer)):
            audio_sampling_rate = [audio_sampling_rate] * len(audio)

        # zip() below would silently drop audio inputs that have no matching sampling rate.
        if len(audio_sampling_rate) != len(audio):
            raise ValueError(
                f"Got {len(audio_sampling_rate)} sampling rates for {len(audio)} audio inputs; "
                "provide a single sampling rate or exactly one per audio input."
            )

        audio_buffers = []
        for array, source_rate in zip(audio, audio_sampling_rate):
            array = np.asarray(array)
            # Convert to mono if needed.
            # NOTE(review): assumes a (num_samples, num_channels) layout — confirm for channel-first inputs.
            if array.ndim == 2:
                array = array.mean(axis=1)
            # Resample if the sampling rate is different from the expected one
            if source_rate != sampling_rate:
                import soxr
                array = soxr.resample(array, source_rate, sampling_rate, quality="HQ")
            audio_buffers.append(array)

        return audio_buffers

    # noinspection PyTypedDict
    def _merge_kwargs(self, **kwargs) -> CM3PProcessorKwargs:
        """
        Merge `self.default_kwargs` with the keyword arguments of a single call, grouped per modality.

        Kwargs may be passed flat (e.g. `max_length=...`) or nested per modality (e.g. `beatmap_kwargs={...}`).

        - A flat key is written into every modality whose TypedDict declares it. That TypedDict is the
          annotation of the modality in `CM3PProcessorKwargs`.
        - Undeclared keys inside a nested dict are passed through to that modality unchanged.
        - Finally, `common_kwargs` is copied into every modality group.

        The caller's nested dicts are never modified.

        Raises:
            ValueError: If the same key is passed both inside a modality dict and as a flat kwarg.
        """
        empty = object()  # sentinel: "no value supplied for this key"
        output_kwargs = CM3PProcessorKwargs()
        nested_modalities = ["beatmap_kwargs", "metadata_kwargs", "audio_kwargs", "common_kwargs"]
        possible_modality_keywords = {"beatmap", "metadata", "audio"}
        used_keys = set()

        # Consumed keys are popped from the nested dicts below. Work on shallow copies so the caller's own
        # dicts keep their settings when they are reused for another call.
        kwargs = {
            key: dict(value) if key in nested_modalities and isinstance(value, dict) else value
            for key, value in kwargs.items()
        }

        # pass defaults to output dictionary
        output_kwargs.update(copy.deepcopy(self.default_kwargs))

        # update modality kwargs with passed kwargs
        non_modality_kwargs = set(kwargs) - set(output_kwargs)
        for modality, output_kwarg in output_kwargs.items():
            for modality_key in CM3PProcessorKwargs.__annotations__[modality].__annotations__:
                # check if we received a structured kwarg dict or not to handle it correctly
                if modality in kwargs:
                    kwarg_value = kwargs[modality].pop(modality_key, empty)
                    # check if this key was passed as a flat kwarg.
                    if kwarg_value is not empty and modality_key in non_modality_kwargs:
                        raise ValueError(
                            f"Keyword argument {modality_key} was passed two times:\n"
                            f"in a dictionary for {modality} and as a **kwarg."
                        )
                elif modality_key in kwargs:
                    # read (not pop) flat kwargs: several modalities can share a key such as `max_length`
                    kwarg_value = kwargs[modality_key]
                else:
                    kwarg_value = empty
                if kwarg_value is not empty:
                    output_kwarg[modality_key] = kwarg_value
                    used_keys.add(modality_key)

        # Determine if kwargs is a flat dictionary or contains nested dictionaries
        if any(key in nested_modalities for key in kwargs):
            # Declared keys were popped above, so each nested dict now only holds keys its modality does not
            # declare. Pass them through per modality, so the same extra key can go to several modalities.
            for modality, subdict in kwargs.items():
                if modality in nested_modalities:
                    output_kwargs[modality].update(subdict)
        else:
            # kwargs is a flat dictionary
            for key, kwarg in kwargs.items():
                if key not in used_keys:
                    if key in CM3PProcessorKwargs.__annotations__["common_kwargs"].__annotations__:
                        output_kwargs["common_kwargs"][key] = kwarg
                    elif key not in possible_modality_keywords:
                        # noinspection PyUnresolvedReferences
                        logger.warning_once(
                            f"Keyword argument `{key}` is not a valid argument for this processor and will be ignored."
                        )

        # all modality-specific kwargs are updated with common kwargs
        for kwarg in output_kwargs.values():
            kwarg.update(output_kwargs["common_kwargs"])
        return output_kwargs

    def __call__(

        self,

        metadata: Optional[Union[CM3PMetadata, list[CM3PMetadata]]] = None,

        beatmap: Optional[Union[str, list[str], PathLike, list[PathLike], IO[str], list[IO[str]], Beatmap, list[Beatmap]]] = None,

        audio: Optional[Union[str, list[str], Path, list[Path], AudioInput]] = None,

        audio_sampling_rate: Optional[Union[int, list[int]]] = None,

        speed: float = 1.0,

        multiply_metadata: bool = False,

        populate_metadata: bool = False,

        metadata_dropout_prob: float = 0.0,

        metadata_variations: int = 1,

        **kwargs,

    ):
        output_kwargs = self._merge_kwargs(**kwargs)

        beatmap_kwargs: CM3PTokenizerKwargs = output_kwargs["beatmap_kwargs"]
        metadata_kwargs: CM3PTokenizerKwargs = output_kwargs["metadata_kwargs"]
        audio_kwargs: CM3PAudioKwargs = output_kwargs["audio_kwargs"]
        common_kwargs: CommonKwargs = output_kwargs["common_kwargs"]

        window_length_sec = beatmap_kwargs.pop("window_length_sec")
        window_stride_sec = beatmap_kwargs.pop("window_stride_sec")
        min_window_length_sec = beatmap_kwargs.pop("min_window_length_sec", 1.0)
        max_length = beatmap_kwargs.get("max_length", 8000)
        metadata_max_length = metadata_kwargs.get("max_length", 128)
        sampling_rate = audio_kwargs["sampling_rate"]
        max_source_positions = audio_kwargs.get("max_source_positions", 3000)
        audio_kwargs["padding"] = False
        return_tensors = common_kwargs["return_tensors"]

        metadata_encoding, beatmap_encoding, num_audio_tokens, metadata_variation_classes = None, None, None, None

        if return_tensors is not None and return_tensors != "pt":
            raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'` or `return_tensors=None`.")

        if metadata is None and beatmap is None:
            raise ValueError("You have to specify either metadata or beatmap. Both cannot be none.")

        if audio is not None:
            audio = self._load_audio(
                sampling_rate,
                audio,
                audio_sampling_rate=audio_sampling_rate,
            )

        if beatmap is not None:
            if not isinstance(beatmap, list):
                beatmap = [beatmap]

            if audio is not None:
                if len(beatmap) != len(audio):
                    raise ValueError(
                        f"The number of beatmaps ({len(beatmap)}) must match the number of audio ({len(audio)})"
                    )
            else:
                audio = [None] * len(beatmap)

            if multiply_metadata or populate_metadata and metadata is not None:
                matched_metadata = metadata
                if not isinstance(matched_metadata, list):
                    matched_metadata = [matched_metadata]
                if (multiply_metadata or populate_metadata) and len(matched_metadata) != len(beatmap):
                    raise ValueError(
                        f"The number of metadata entries ({len(matched_metadata)}) must match the number of beatmaps ({len(beatmap)})"
                        "` if multiply_metadata` or `populate_metadata` is set to True."
                    )
            else:
                matched_metadata = [CM3PMetadata()] * len(beatmap) if populate_metadata else [None] * len(beatmap)

            new_metadata = []
            batch_start_ms = []
            batch_groups = []
            batch_audio = []
            batch_num_audio_tokens = []
            for b, m, audio_array in zip(beatmap, matched_metadata, audio):
                b: Beatmap = load_beatmap(b)
                song_length = get_song_length(audio_array, sampling_rate, b)
                beatmap_groups = self.beatmap_parser.parse_beatmap(b, speed=speed, song_length=song_length)

                def add_metadata(song_position: Optional[float] = None):
                    if populate_metadata:
                        new_metadata.append(merge_metadata_dicts(m, get_metadata(
                            beatmap=b,
                            audio_samples=audio_array,
                            sampling_rate=sampling_rate,
                            speed=speed,
                            song_position=song_position,
                        )))
                    else:
                        new_metadata.append(m)

                if not multiply_metadata:
                    add_metadata()

                # Loop through with sliding window
                groups_search_index = 0
                for start_sec in np.arange(0, song_length - min_window_length_sec, window_stride_sec):
                    end_sec = start_sec + window_length_sec

                    if audio_array is not None:
                        # Slice audio waveform
                        start_frame = int(start_sec * sampling_rate)
                        end_frame = int(end_sec * sampling_rate)
                        audio_slice = audio_array[start_frame:end_frame]
                        # Pad the audio array and calculate the number of audio tokens
                        audio_slice, num_audio_tokens = self._encode_audio(audio_slice, **audio_kwargs)
                    else:
                        audio_slice = None
                        num_audio_tokens = 0

                    # Find groups that fall within the current window
                    # Groups are sorted by time, so we can use a simple linear search from the last index
                    start_ms = start_sec * 1000
                    end_ms = end_sec * 1000
                    next_start_ms = (start_sec + window_stride_sec) * 1000
                    window_groups = []
                    for group in itertools.islice(beatmap_groups, groups_search_index, None):
                        if group.time < next_start_ms:
                            groups_search_index += 1

                        if group.time < start_ms:
                            continue
                        elif group.time < end_ms:
                            window_groups.append(group)
                        else:
                            break

                    batch_start_ms.append(start_ms)
                    batch_groups.append(window_groups)
                    batch_audio.append(audio_slice)
                    batch_num_audio_tokens.append(num_audio_tokens)

                    if multiply_metadata:
                        add_metadata(start_sec / song_length)

            if populate_metadata or multiply_metadata:
                metadata = new_metadata

            if len(batch_groups) > 0:
                beatmap_encoding = self.beatmap_tokenizer(
                    groups=batch_groups,
                    window_start_ms=batch_start_ms,
                    num_audio_tokens=batch_num_audio_tokens,
                    **beatmap_kwargs,
                )

                if all(a is not None for a in audio):
                    data = dict(beatmap_encoding)
                    data["input_features"] = self._retrieve_input_features(batch_audio, **audio_kwargs)
                    beatmap_encoding = BatchFeature(data, tensor_type=return_tensors)
            else:
                # No windows with hit objects were found, return empty encoding
                logger.warning("Warning: No windows with hit objects were found in the provided beatmap(s). Returning empty encoding.")
                beatmap_encoding = BatchEncoding(
                    {
                        "input_ids": torch.zeros((0, max_length), dtype=torch.long) if return_tensors == "pt" else [],
                        "attention_mask": torch.zeros((0, max_length), dtype=torch.long) if return_tensors == "pt" else [],
                    },
                    tensor_type=return_tensors,
                )
                if all(a is not None for a in audio):
                    data = dict(beatmap_encoding)
                    data["input_features"] = torch.zeros((0, self.audio_feature_extractor.feature_size, max_source_positions), dtype=torch.float) if return_tensors == "pt" else []
                    beatmap_encoding = BatchFeature(data, tensor_type=return_tensors)

        if metadata is not None and not (isinstance(metadata, list) and any(m is None for m in metadata)):
            if not isinstance(metadata, list):
                metadata = [metadata]

            if metadata_dropout_prob > 0.0:
                for m in metadata:
                    # Randomly drop out metadata fields
                    for key, value in m.items():
                        if value is not None and np.random.rand() < metadata_dropout_prob:
                            # noinspection PyTypedDict
                            m[key] = None

            if metadata_variations > 1:
                extended_metadata = []
                metadata_variation_classes = []
                for m in metadata:
                    m_vars, m_classes = zip(*self.metadata_tokenizer.metadata_variations(m, metadata_variations - 1))
                    extended_metadata.append(m)
                    extended_metadata.extend(m_vars)
                    metadata_variation_classes.append([0] + list(m_classes))  # Class 0 is the original metadata

                assert len(extended_metadata) == len(metadata) * metadata_variations
                metadata = extended_metadata

            if len(metadata) > 0:
                metadata_encoding = self.metadata_tokenizer(
                    metadata,
                    **metadata_kwargs,
                )
                if metadata_variations > 1:
                    # Reshape to (batch_size, variations, seq_len)
                    for k, v in metadata_encoding.items():
                        if return_tensors == "pt":
                            v = v.view(len(metadata) // metadata_variations, metadata_variations, -1)
                        else:
                            v = [v[i:i + metadata_variations] for i in range(0, len(v), metadata_variations)]
                        metadata_encoding[k] = v
                if metadata_variation_classes is not None:
                    metadata_encoding["metadata_variation_classes"] = torch.tensor(metadata_variation_classes, dtype=torch.long) if return_tensors == "pt" else metadata_variation_classes
            else:
                metadata_encoding = BatchEncoding(
                    {
                        "input_ids": torch.zeros((0, metadata_max_length), dtype=torch.long) if return_tensors == "pt" else [],
                        "attention_mask": torch.zeros((0, metadata_max_length), dtype=torch.long) if return_tensors == "pt" else [],
                    },
                    tensor_type=return_tensors,
                )

        if metadata_encoding is not None and beatmap_encoding is not None:
            beatmap_encoding["metadata_ids"] = metadata_encoding["input_ids"]
            beatmap_encoding["metadata_attention_mask"] = metadata_encoding["attention_mask"]
            if "metadata_variation_classes" in metadata_encoding:
                beatmap_encoding["metadata_variation_classes"] = metadata_encoding["metadata_variation_classes"]
            return beatmap_encoding
        elif beatmap_encoding is not None:
            return beatmap_encoding
        else:
            return metadata_encoding

    def batch_decode(self, *positional, **options):
        """
        Decode a batch of beatmap token sequences.

        Thin wrapper: every positional and keyword argument is passed unchanged to
        [`~CM3PBeatmapTokenizer.batch_decode`] on `self.beatmap_tokenizer`, and its result is
        returned as-is. See that method's docstring for the accepted arguments.
        """
        tokenizer = self.beatmap_tokenizer
        decoded = tokenizer.batch_decode(*positional, **options)
        return decoded

    def decode(self, *positional, **options):
        """
        Decode a single beatmap token sequence.

        Thin wrapper: every positional and keyword argument is passed unchanged to
        [`~CM3PBeatmapTokenizer.decode`] on `self.beatmap_tokenizer`, and its result is returned
        as-is. See that method's docstring for the accepted arguments.
        """
        tokenizer = self.beatmap_tokenizer
        decoded = tokenizer.decode(*positional, **options)
        return decoded

    def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
        """
        Save the processor and its sub-components, with support for AutoProcessor remote code.

        This is a lightly adapted version of `ProcessorMixin.save_pretrained`:

        - every entry of `self.attributes` is saved into its own subfolder named after the
          attribute (audio_feature_extractor/, beatmap_parser/, ...);
        - when `self._auto_class` is set (via `register_for_auto_class`), `custom_object_save`
          is used so that auto_map and dynamic modules are written correctly.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory to save into; created if it does not exist.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether to upload the saved files to the Hub afterwards.
            kwargs:
                Hub options. `commit_message` and `repo_id` are consumed here (`repo_id` defaults
                to the last path component of `save_directory`); the remaining kwargs are
                forwarded to `_create_repo`, and `token`, `create_pr`, `revision` and
                `commit_description` are also read for the upload.

        Returns:
            `list[str]`: `[<processor config path>]`, or `[]` when the processor config would only
            contain `processor_class` and was therefore not written.
        """
        os.makedirs(save_directory, exist_ok=True)

        # Hub integration (mirrors ProcessorMixin): timestamps are snapshotted before anything is
        # saved so that _upload_modified_files can tell which files this call wrote.
        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            # os.fspath accepts pathlib.Path inputs; stripping trailing separators keeps
            # "out/" from producing an empty repo id.
            default_repo_id = os.fspath(save_directory).rstrip(os.path.sep).split(os.path.sep)[-1]
            repo_id = kwargs.pop("repo_id", default_repo_id)
            repo_id = self._create_repo(repo_id, **kwargs)
            files_timestamps = self._get_files_timestamps(save_directory)
        else:
            commit_message = None
            repo_id = None
            files_timestamps = None

        try:
            # If this processor is registered for an Auto class, save its code and dependencies
            # as a dynamic module and populate the auto_map field.
            if self._auto_class is not None:
                configs = []
                for attribute_name in self.attributes:
                    attribute = getattr(self, attribute_name)
                    # Tokenizers contribute their init_kwargs dict (auto_map gets injected there
                    # and is removed again in the finally below); other objects go in as-is.
                    if isinstance(attribute, PreTrainedTokenizerBase):
                        configs.append(attribute.init_kwargs)
                    else:
                        configs.append(attribute)
                # Include the processor itself so its class is exported.
                configs.append(self)
                custom_object_save(self, save_directory, config=configs)

            # Save each sub-component into its own subfolder.
            for attribute_name in self.attributes:
                attribute = getattr(self, attribute_name)
                # Record the processor class in the attribute config so this processor can be
                # reloaded with the AutoProcessor API.
                if hasattr(attribute, "_set_processor_class"):
                    # noinspection PyProtectedMember
                    attribute._set_processor_class(self.__class__.__name__)
                attribute.save_pretrained(os.path.join(save_directory, attribute_name))
        finally:
            # Remove the temporary auto_map from tokenizer init_kwargs even when a save above
            # failed, so the in-memory tokenizers are not left modified.
            if self._auto_class is not None:
                for attribute_name in self.attributes:
                    attribute = getattr(self, attribute_name)
                    if isinstance(attribute, PreTrainedTokenizerBase) and "auto_map" in attribute.init_kwargs:
                        del attribute.init_kwargs["auto_map"]

        output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
        processor_dict = self.to_dict()

        # Matching upstream, the config file is skipped when it would only hold processor_class.
        writes_config = set(processor_dict) != {"processor_class"}
        if writes_config:
            self.to_json_file(output_processor_file)
            # noinspection PyUnresolvedReferences
            logger.warning_once(f"processor saved in {output_processor_file}")

        if push_to_hub:
            self._upload_modified_files(
                save_directory,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=kwargs.get("token"),
                create_pr=kwargs.get("create_pr", False),
                revision=kwargs.get("revision"),
                commit_description=kwargs.get("commit_description"),
            )

        return [output_processor_file] if writes_config else []

    @classmethod
    def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load every sub-component listed in `cls.attributes` from its own subfolder.

        Each attribute `<name>` is resolved through `cls.<name>_class` and loaded with
        `from_pretrained` from the subfolder `<name>`. When a `subfolder` kwarg is given, the
        subfolder is `<subfolder>/<name>` instead. This matches the per-attribute subfolders
        written by `save_pretrained`. All other kwargs are forwarded unchanged.

        Returns:
            `list`: the loaded sub-components, in `cls.attributes` order.
        """
        import posixpath

        subfolder = kwargs.pop("subfolder", None)
        args = []
        for attribute_name in cls.attributes:
            class_name = getattr(cls, f"{attribute_name}_class")
            attribute_class = cls.get_possibly_dynamic_module(class_name)
            # Hub subfolders are repository paths and must use "/" on every OS; os.path.join
            # would insert "\" on Windows. Forward slashes also work for local paths there.
            attribute_subfolder = posixpath.join(subfolder, attribute_name) if subfolder else attribute_name

            args.append(
                attribute_class.from_pretrained(
                    pretrained_model_name_or_path,
                    subfolder=attribute_subfolder,
                    **kwargs,
                )
            )

        return args

    def _upload_modified_files(
        self,
        working_dir: Union[str, os.PathLike],
        repo_id: str,
        files_timestamps: dict[str, float],
        commit_message: Optional[str] = None,
        token: Optional[Union[bool, str]] = None,
        create_pr: bool = False,
        revision: Optional[str] = None,
        commit_description: Optional[str] = None,
    ):
        """
        Upload the files under `working_dir` that changed since `files_timestamps` was taken.

        `files_timestamps` is the snapshot of top-level entries recorded by `save_pretrained`
        before saving. A file is uploaded when the top-level entry containing it is not in the
        snapshot, or when the file's own mtime is newer than that entry's recorded timestamp.

        Sub-folders are walked recursively, so only regular files are ever committed. Each file's
        own mtime is compared because rewriting a file in place inside an existing sub-folder
        does not, on common filesystems, change the sub-folder's mtime.

        Args:
            working_dir: Local directory whose contents are pushed.
            repo_id: Target Hub repository.
            files_timestamps: Snapshot of top-level entry modification times.
            commit_message: Commit title; defaults to "Upload CM3P processor".
            token: Hub token forwarded to `create_branch` / `create_commit`.
            create_pr: Open a pull request instead of committing directly.
            revision: Target branch; created first unless it is a `refs/pr` ref.
            commit_description: Optional commit body.

        Returns:
            The result of `create_commit`.
        """
        working_dir = Path(working_dir)

        if commit_message is None:
            commit_message = "Upload CM3P processor"

        def recorded_mtime(entry: Path) -> Optional[float]:
            # Accept snapshots keyed either by full path (str(entry)) or by bare entry name.
            key = str(entry)
            if key in files_timestamps:
                return files_timestamps[key]
            return files_timestamps.get(entry.name)

        files_to_upload: list[Path] = []
        for entry in sorted(working_dir.iterdir()):
            if entry.is_file():
                candidates = [entry]
            elif entry.is_dir():
                # Directories cannot be committed themselves, and nested folders may exist, so
                # collect every regular file below this entry.
                candidates = sorted(p for p in entry.rglob("*") if p.is_file())
            else:
                continue
            baseline = recorded_mtime(entry)
            files_to_upload.extend(
                p for p in candidates if baseline is None or p.stat().st_mtime > baseline
            )

        operations = [
            CommitOperationAdd(path_or_fileobj=p, path_in_repo=p.relative_to(working_dir).as_posix())
            for p in files_to_upload
        ]

        if revision is not None and not revision.startswith("refs/pr"):
            try:
                create_branch(repo_id=repo_id, branch=revision, token=token, exist_ok=True)
            except HfHubHTTPError as e:
                # When opening a PR on a repo we cannot write to, the branch cannot be created;
                # assume it already exists and let create_commit raise if it does not.
                if not (e.response.status_code == 403 and create_pr):
                    raise

        uploaded = ",".join(p.relative_to(working_dir).as_posix() for p in files_to_upload)
        logger.info(f"Uploading the following files to {repo_id}: {uploaded}")
        return create_commit(
            repo_id=repo_id,
            operations=operations,
            commit_message=commit_message,
            commit_description=commit_description,
            token=token,
            create_pr=create_pr,
            revision=revision,
        )

# Register CM3PProcessor as the AutoProcessor class for CM3PConfig when this module is imported.
AutoProcessor.register(CM3PConfig, CM3PProcessor)

# Public API of this module.
__all__ = ["CM3PProcessor", "get_metadata"]