Instructions to use B-K/song2midi-processor with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use B-K/song2midi-processor with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("B-K/song2midi-processor", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
Upload tokenization_song2midi.py
Browse files
- tokenization_song2midi.py +25 -5
tokenization_song2midi.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
from pathlib import Path
|
| 3 |
-
from typing import Union
|
| 4 |
|
| 5 |
from transformers import BatchEncoding, PythonBackend
|
| 6 |
from transformers.tokenization_utils_base import TruncationStrategy
|
|
@@ -8,7 +8,9 @@ from transformers.utils.generic import PaddingStrategy, TensorType
|
|
| 8 |
|
| 9 |
try:
|
| 10 |
from miditok import PerTok, TokSequence
|
| 11 |
-
|
|
|
|
|
|
|
| 12 |
except ImportError:
|
| 13 |
raise ImportError(
|
| 14 |
"The `miditok` library is required for processing MIDI files. "
|
|
@@ -29,6 +31,11 @@ class Song2MIDIPerTokTokenizer(PythonBackend):
|
|
| 29 |
**kwargs,
|
| 30 |
):
|
| 31 |
self._tokenizer = PerTok(params=vocab_file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
self._decoder = {value: key for key, value in self._tokenizer.vocab.items()}
|
| 34 |
|
|
@@ -49,8 +56,8 @@ class Song2MIDIPerTokTokenizer(PythonBackend):
|
|
| 49 |
|
| 50 |
def _encode_plus(
|
| 51 |
self,
|
| 52 |
-
text: Union[
|
| 53 |
-
text_pair: Union[
|
| 54 |
add_special_tokens: bool = True,
|
| 55 |
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
| 56 |
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
|
|
@@ -132,7 +139,7 @@ class Song2MIDIPerTokTokenizer(PythonBackend):
|
|
| 132 |
return []
|
| 133 |
if isinstance(midi_input, (str, Path, Score, bytes)):
|
| 134 |
if isinstance(midi_input, bytes):
|
| 135 |
-
midi_input = Score.from_midi(midi_input)
|
| 136 |
return self._tokenizer.encode(midi_input).ids
|
| 137 |
if isinstance(midi_input, (list, tuple)) and midi_input:
|
| 138 |
if isinstance(midi_input[0], int):
|
|
@@ -186,7 +193,20 @@ class Song2MIDIPerTokTokenizer(PythonBackend):
|
|
| 186 |
]
|
| 187 |
|
| 188 |
return " ".join(tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
|
|
|
|
|
|
|
|
|
| 190 |
def save_vocabulary(
|
| 191 |
self, save_directory: str, filename_prefix: str | None = None
|
| 192 |
) -> tuple[str, ...]:
|
|
|
|
| 1 |
import os
|
| 2 |
from pathlib import Path
|
| 3 |
+
from typing import Union, TYPE_CHECKING
|
| 4 |
|
| 5 |
from transformers import BatchEncoding, PythonBackend
|
| 6 |
from transformers.tokenization_utils_base import TruncationStrategy
|
|
|
|
| 8 |
|
| 9 |
try:
|
| 10 |
from miditok import PerTok, TokSequence
|
| 11 |
+
import symusic
|
| 12 |
+
if TYPE_CHECKING:
|
| 13 |
+
from symusic.types import Score
|
| 14 |
except ImportError:
|
| 15 |
raise ImportError(
|
| 16 |
"The `miditok` library is required for processing MIDI files. "
|
|
|
|
| 31 |
**kwargs,
|
| 32 |
):
|
| 33 |
self._tokenizer = PerTok(params=vocab_file)
|
| 34 |
+
|
| 35 |
+
# PerTok as of miditok version 3.0.6.post1 does not load position token locations from the vocab file.
|
| 36 |
+
# use_position_toks workaround
|
| 37 |
+
if self._tokenizer.use_position_toks and not getattr(self._tokenizer, "position_locations", None):
|
| 38 |
+
self._tokenizer.position_locations = self._tokenizer._create_position_tok_locations()
|
| 39 |
|
| 40 |
self._decoder = {value: key for key, value in self._tokenizer.vocab.items()}
|
| 41 |
|
|
|
|
| 56 |
|
| 57 |
def _encode_plus(
|
| 58 |
self,
|
| 59 |
+
text: Union[Score, Path, bytes, list[Union[Score, Path, bytes]], list[int]],
|
| 60 |
+
text_pair: Union[Score, Path, list[Union[Score, Path]], list[int], None] = None,
|
| 61 |
add_special_tokens: bool = True,
|
| 62 |
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
| 63 |
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
|
|
|
|
| 139 |
return []
|
| 140 |
if isinstance(midi_input, (str, Path, Score, bytes)):
|
| 141 |
if isinstance(midi_input, bytes):
|
| 142 |
+
midi_input = symusic.Score.from_midi(midi_input)
|
| 143 |
return self._tokenizer.encode(midi_input).ids
|
| 144 |
if isinstance(midi_input, (list, tuple)) and midi_input:
|
| 145 |
if isinstance(midi_input[0], int):
|
|
|
|
| 193 |
]
|
| 194 |
|
| 195 |
return " ".join(tokens)
|
| 196 |
+
|
| 197 |
+
def decode_score(
|
| 198 |
+
self,
|
| 199 |
+
token_ids: int | list[int],
|
| 200 |
+
skip_special_tokens: bool = False,
|
| 201 |
+
clean_up_tokenization_spaces: bool | None = None,
|
| 202 |
+
**kwargs,
|
| 203 |
+
) -> Score:
|
| 204 |
+
if isinstance(token_ids, int):
|
| 205 |
+
token_ids = [token_ids]
|
| 206 |
|
| 207 |
+
tok_sequence = TokSequence(ids=token_ids, are_ids_encoded=True)
|
| 208 |
+
return self._tokenizer.decode(tok_sequence)
|
| 209 |
+
|
| 210 |
def save_vocabulary(
|
| 211 |
self, save_directory: str, filename_prefix: str | None = None
|
| 212 |
) -> tuple[str, ...]:
|