B-K commited on
Commit
24662af
·
verified ·
1 Parent(s): 78277d7

Upload tokenization_song2midi.py

Browse files
Files changed (1) hide show
  1. tokenization_song2midi.py +25 -5
tokenization_song2midi.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  from pathlib import Path
3
- from typing import Union
4
 
5
  from transformers import BatchEncoding, PythonBackend
6
  from transformers.tokenization_utils_base import TruncationStrategy
@@ -8,7 +8,9 @@ from transformers.utils.generic import PaddingStrategy, TensorType
8
 
9
  try:
10
  from miditok import PerTok, TokSequence
11
- from symusic import Score
 
 
12
  except ImportError:
13
  raise ImportError(
14
  "The `miditok` library is required for processing MIDI files. "
@@ -29,6 +31,11 @@ class Song2MIDIPerTokTokenizer(PythonBackend):
29
  **kwargs,
30
  ):
31
  self._tokenizer = PerTok(params=vocab_file)
 
 
 
 
 
32
 
33
  self._decoder = {value: key for key, value in self._tokenizer.vocab.items()}
34
 
@@ -49,8 +56,8 @@ class Song2MIDIPerTokTokenizer(PythonBackend):
49
 
50
  def _encode_plus(
51
  self,
52
- text: Union["Score", Path, bytes, list[Union["Score", Path, bytes]], list[int]],
53
- text_pair: Union["Score", Path, list[Union["Score", Path]], list[int], None] = None,
54
  add_special_tokens: bool = True,
55
  padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
56
  truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
@@ -132,7 +139,7 @@ class Song2MIDIPerTokTokenizer(PythonBackend):
132
  return []
133
  if isinstance(midi_input, (str, Path, Score, bytes)):
134
  if isinstance(midi_input, bytes):
135
- midi_input = Score.from_midi(midi_input)
136
  return self._tokenizer.encode(midi_input).ids
137
  if isinstance(midi_input, (list, tuple)) and midi_input:
138
  if isinstance(midi_input[0], int):
@@ -186,7 +193,20 @@ class Song2MIDIPerTokTokenizer(PythonBackend):
186
  ]
187
 
188
  return " ".join(tokens)
 
 
 
 
 
 
 
 
 
 
189
 
 
 
 
190
  def save_vocabulary(
191
  self, save_directory: str, filename_prefix: str | None = None
192
  ) -> tuple[str, ...]:
 
1
  import os
2
  from pathlib import Path
3
+ from typing import Union, TYPE_CHECKING
4
 
5
  from transformers import BatchEncoding, PythonBackend
6
  from transformers.tokenization_utils_base import TruncationStrategy
 
8
 
9
  try:
10
  from miditok import PerTok, TokSequence
11
+ import symusic
12
+ if TYPE_CHECKING:
13
+ from symusic.types import Score
14
  except ImportError:
15
  raise ImportError(
16
  "The `miditok` library is required for processing MIDI files. "
 
31
  **kwargs,
32
  ):
33
  self._tokenizer = PerTok(params=vocab_file)
34
+
35
+ # PerTok as of miditok version 3.0.6.post1 does not load position token locations from the vocab file.
36
+ # use_position_toks workaround
37
+ if self._tokenizer.use_position_toks and not getattr(self._tokenizer, "position_locations", None):
38
+ self._tokenizer.position_locations = self._tokenizer._create_position_tok_locations()
39
 
40
  self._decoder = {value: key for key, value in self._tokenizer.vocab.items()}
41
 
 
56
 
57
  def _encode_plus(
58
  self,
59
+ text: Union[Score, Path, bytes, list[Union[Score, Path, bytes]], list[int]],
60
+ text_pair: Union[Score, Path, list[Union[Score, Path]], list[int], None] = None,
61
  add_special_tokens: bool = True,
62
  padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
63
  truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
 
139
  return []
140
  if isinstance(midi_input, (str, Path, Score, bytes)):
141
  if isinstance(midi_input, bytes):
142
+ midi_input = symusic.Score.from_midi(midi_input)
143
  return self._tokenizer.encode(midi_input).ids
144
  if isinstance(midi_input, (list, tuple)) and midi_input:
145
  if isinstance(midi_input[0], int):
 
193
  ]
194
 
195
  return " ".join(tokens)
196
+
197
+ def decode_score(
198
+ self,
199
+ token_ids: int | list[int],
200
+ skip_special_tokens: bool = False,
201
+ clean_up_tokenization_spaces: bool | None = None,
202
+ **kwargs,
203
+ ) -> Score:
204
+ if isinstance(token_ids, int):
205
+ token_ids = [token_ids]
206
 
207
+ tok_sequence = TokSequence(ids=token_ids, are_ids_encoded=True)
208
+ return self._tokenizer.decode(tok_sequence)
209
+
210
  def save_vocabulary(
211
  self, save_directory: str, filename_prefix: str | None = None
212
  ) -> tuple[str, ...]: