File size: 8,624 Bytes
65965c8
 
56ce0c8
65965c8
 
 
 
 
 
 
56ce0c8
24662af
65965c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24662af
 
 
 
 
65965c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24662af
 
65965c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
038931f
 
65965c8
 
24662af
78277d7
65965c8
 
 
 
038931f
65965c8
 
 
 
038931f
65965c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24662af
 
 
 
 
 
 
 
 
 
65965c8
24662af
 
 
65965c8
 
 
 
 
 
 
 
 
038931f
65965c8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import os
from pathlib import Path
from typing import Union

from transformers import BatchEncoding, PythonBackend
from transformers.tokenization_utils_base import TruncationStrategy
from transformers.utils.generic import PaddingStrategy, TensorType

try:
    from miditok import PerTok, TokSequence
    from symusic.types import Score
    import symusic
except ImportError:
    raise ImportError(
        "The `miditok` library is required for processing MIDI files. "
        "Please install it with `pip install miditok`."
    )


class Song2MIDIPerTokTokenizer(PythonBackend):
    """Hugging Face tokenizer wrapper around a MidiTok ``PerTok`` tokenizer.

    Encodes MIDI inputs (``symusic`` scores, paths to MIDI files, raw MIDI
    bytes, or already-tokenized id lists) into ``BatchEncoding`` objects, and
    decodes token ids back into a string of token names (:meth:`_decode`) or
    into a ``Score`` (:meth:`decode_score`). The MidiTok params are stored in
    ``vocab.json`` and written back by :meth:`save_vocabulary`.
    """

    vocab_files_names = {"vocab_file": "vocab.json"}

    def __init__(
        self,
        vocab_file: str | os.PathLike | Path,
        unk_token: str = "UNK_None",
        bos_token: str = "BOS_None",
        eos_token: str = "EOS_None",
        pad_token: str = "PAD_None",
        **kwargs,
    ):
        """Load the PerTok tokenizer from ``vocab_file`` and initialise the HF base.

        Args:
            vocab_file: Path to the MidiTok params file (``vocab.json``).
            unk_token: Unknown-token string.
            bos_token: Beginning-of-sequence token string.
            eos_token: End-of-sequence token string.
            pad_token: Padding token string.
            **kwargs: Forwarded unchanged to ``PythonBackend.__init__``.
        """
        self._tokenizer = PerTok(params=vocab_file)

        # PerTok (as of miditok 3.0.6.post1) does not restore its position-token
        # locations when loaded from a params file; rebuild them when position
        # tokens are enabled but the locations are missing.
        if self._tokenizer.use_position_toks and not getattr(
            self._tokenizer, "position_locations", None
        ):
            self._tokenizer.position_locations = (
                self._tokenizer._create_position_tok_locations()
            )

        # Reverse lookup (base-vocabulary id -> token string) used by _decode.
        self._decoder = {value: key for key, value in self._tokenizer.vocab.items()}

        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Vocabulary size as reported by ``len()`` of the wrapped MidiTok tokenizer."""
        return len(self._tokenizer)

    def get_vocab(self):
        """Return a copy of the MidiTok vocabulary (token string -> id).

        A copy is returned so callers cannot mutate the wrapped tokenizer's
        vocabulary and silently desynchronise ``self._decoder``.
        """
        return dict(self._tokenizer.vocab)

    def _encode_plus(
        self,
        text: Union[Score, Path, bytes, list[Union[Score, Path, bytes]], list[int]],
        text_pair: Union[Score, Path, list[Union[Score, Path]], list[int], None] = None,
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: int | None = None,
        stride: int = 0,
        pad_to_multiple_of: int | None = None,
        padding_side: str | None = None,
        return_tensors: str | TensorType | None = None,
        return_token_type_ids: bool | None = None,
        return_attention_mask: bool | None = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ):  # ty: ignore[invalid-method-override]
        """Encode one MIDI input (or a batch of them) into a ``BatchEncoding``.

        ``text`` / ``text_pair`` may each be a ``Score``, a path (``str`` or
        ``Path``) to a MIDI file, raw MIDI ``bytes``, or a list/tuple of token
        ids. A list/tuple whose first element is a ``Score``/path/bytes (or an
        empty list) is treated as a batch; each item is encoded separately and
        the batch is padded and converted to tensors at the end.

        Raises:
            ValueError: If a batched ``text`` is paired with a ``text_pair`` that
                is not a batch of the same length, or if an input is of an
                unsupported type.
        """
        midi = text
        midi_pair = text_pair

        # From https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_python.py (v5.3.0)
        is_batched = isinstance(midi, (list, tuple)) and (
            (not midi) or (midi and isinstance(midi[0], (str, Path, Score, bytes)))
        )

        if is_batched:
            if midi_pair is not None:
                if not isinstance(midi_pair, (list, tuple)) or len(midi_pair) != len(
                    midi
                ):
                    raise ValueError(
                        "If `midi` is a batch, `midi_pair` must be a batch of the same length."
                    )
            pairs = midi_pair if midi_pair is not None else [None] * len(midi)

            batch_outputs = {}
            for current_midi, current_pair in zip(midi, pairs):
                current_output = self._encode_plus(
                    text=current_midi,
                    text_pair=current_pair,
                    add_special_tokens=add_special_tokens,
                    padding_strategy=PaddingStrategy.DO_NOT_PAD,  # we pad in batch afterward
                    truncation_strategy=truncation_strategy,
                    max_length=max_length,
                    stride=stride,
                    pad_to_multiple_of=None,  # we pad in batch afterward
                    padding_side=None,  # we pad in batch afterward
                    return_tensors=None,  # we convert the whole batch to tensors at the end
                    return_token_type_ids=return_token_type_ids,
                    return_attention_mask=False,  # we pad in batch afterward
                    return_overflowing_tokens=return_overflowing_tokens,
                    return_special_tokens_mask=return_special_tokens_mask,
                    return_length=return_length,
                    verbose=verbose,
                    **kwargs,
                )
                for key, value in current_output.items():
                    batch_outputs.setdefault(key, []).append(value)

            # Remove overflow-related keys before tensor conversion if return_tensors is set
            # Slow tokenizers don't support returning these as tensors
            if return_tensors and return_overflowing_tokens:
                batch_outputs.pop("overflowing_tokens", None)
                batch_outputs.pop("num_truncated_tokens", None)

            batch_outputs = self.pad(
                batch_outputs,
                padding=padding_strategy.value,
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
                padding_side=padding_side,
                return_attention_mask=return_attention_mask,
            )

            return BatchEncoding(batch_outputs, tensor_type=return_tensors)

        # Single sequence handling
        def get_input_ids(midi_input):
            """Convert one MIDI input, or an already-tokenized id sequence, to a list of ids."""
            if not midi_input:
                return []
            if isinstance(midi_input, (str, Path, Score, bytes)):
                if isinstance(midi_input, bytes):
                    # Raw MIDI file contents: parse them in memory with symusic.
                    midi_input = symusic.Score.from_midi(midi_input)
                # NOTE(review): assumes the PerTok config yields a single
                # TokSequence (one token stream); a list of sequences would have
                # no ``.ids`` attribute — confirm against the tokenizer config.
                return self._tokenizer.encode(midi_input).ids
            if isinstance(midi_input, (list, tuple)) and isinstance(midi_input[0], int):
                # Already tokenized. Copy into a list so tuples concatenate
                # correctly with the list-based special-token ids.
                return list(midi_input)

            raise ValueError(
                "Input must be a Score, a path to a MIDI file, or a list of token IDs."
            )

        first_ids = get_input_ids(midi)
        second_ids = get_input_ids(midi_pair) if midi_pair is not None else None

        return self.prepare_for_model(
            first_ids,
            pair_ids=second_ids,
            add_special_tokens=add_special_tokens,
            padding=padding_strategy.value,
            truncation=truncation_strategy.value,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=padding_side,
            # Without this, a single (non-batched) input ignores return_tensors
            # and prepend_batch_axis has nothing to act on.
            return_tensors=return_tensors,
            prepend_batch_axis=True,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length,
            verbose=verbose,
        )

    def _make_tok_sequence(self, token_ids: list[int]) -> TokSequence:
        """Wrap ``token_ids`` in a ``TokSequence`` flagged for the wrapped tokenizer.

        Ids are flagged as model-encoded only when the MidiTok tokenizer reports
        itself as trained; otherwise they are treated as base-vocabulary ids.
        Falls back to ``True`` (the historical behaviour of this class) if the
        installed miditok version has no ``is_trained`` attribute.
        """
        are_ids_encoded = bool(getattr(self._tokenizer, "is_trained", True))
        return TokSequence(ids=list(token_ids), are_ids_encoded=are_ids_encoded)

    def _decode(
        self,
        token_ids: int | list[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool | None = None,
        **kwargs,
    ) -> str:
        """Decode token ids into a space-separated string of MidiTok token names.

        Args:
            token_ids: A single id or a sequence of ids produced by this tokenizer.
            skip_special_tokens: Drop tokens listed in the MidiTok tokenizer's
                ``special_tokens``.
            clean_up_tokenization_spaces: Accepted for API compatibility; unused.
            **kwargs: Accepted for API compatibility; unused.

        Returns:
            The base-vocabulary token names joined by single spaces.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        tok_sequence = self._make_tok_sequence(token_ids)
        if tok_sequence.are_ids_encoded:
            # Expand ids of the trained model back into base-vocabulary ids.
            self._tokenizer.decode_token_ids(tok_sequence)

        tokens = [self._decoder[token_id] for token_id in tok_sequence.ids]

        if skip_special_tokens:
            special_tokens = set(self._tokenizer.special_tokens)
            tokens = [token for token in tokens if token not in special_tokens]

        return " ".join(tokens)

    def decode_score(
        self,
        token_ids: int | list[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool | None = None,
        **kwargs,
    ) -> Score:
        """Decode token ids into a ``Score`` using the wrapped MidiTok tokenizer.

        ``skip_special_tokens``, ``clean_up_tokenization_spaces`` and ``kwargs``
        mirror :meth:`_decode`'s signature and are currently ignored.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        return self._tokenizer.decode(self._make_tok_sequence(token_ids))

    def save_vocabulary(
        self, save_directory: str | os.PathLike, filename_prefix: str | None = None
    ) -> tuple[str, ...]:
        """Save the MidiTok tokenizer params to ``[<prefix>-]vocab.json``.

        Args:
            save_directory: Existing directory to write into.
            filename_prefix: Optional prefix, joined to the filename with ``-``.

        Returns:
            A one-element tuple with the written path, or an empty tuple (after
            logging an error) when ``save_directory`` is not a directory.
        """
        import logging

        if not os.path.isdir(save_directory):
            logging.getLogger(__name__).error(
                "Vocabulary path (%s) should be a directory", save_directory
            )
            return ()

        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(
            save_directory, prefix + self.vocab_files_names["vocab_file"]
        )

        # Use MidiTok's own serialization
        self._tokenizer.save(vocab_file)

        return (vocab_file,)