|
|
|
|
|
"""miditokenizer.ipynb |
|
|
|
|
|
Automatically generated by Colab. |
|
|
|
|
|
Original file is located at |
|
|
https://colab.research.google.com/drive/16YJUBYcqKYPVIhwzKNi4ELnTftr2TcUY |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import GPT2TokenizerFast, GPT2LMHeadModel |
|
|
from torch.utils.data import Dataset |
|
|
from pathlib import Path |
|
|
import torch |
|
|
|
|
|
class MIDITokenizer:
    """Tokenization specific to MIDI data with special tokens.

    Thin wrapper around a Hugging Face ``GPT2TokenizerFast`` that registers
    MIDI structural/event tokens (and, optionally, per-composer tokens) as
    special tokens and otherwise delegates to the wrapped tokenizer.
    """

    def __init__(self, pretrained_model='gpt2'):
        """Load the base tokenizer and register the MIDI special tokens.

        Args:
            pretrained_model: Name or path passed to
                ``GPT2TokenizerFast.from_pretrained``.
        """
        self.base_tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_model)

        special_tokens = {
            'additional_special_tokens': [
                '<|START_METADATA|>',
                '<|END_METADATA|>',
                '<|START_TRACK|>',
                '<|END_TRACK|>',
                '<metadata>',
                '<tempo>',
                '<time_signature>',
                '<program_change>',
                '<note_on>',
                '<note_off>',
                '<control_change>'
            ],
            # An explicit pad token is registered so batches can be padded.
            'pad_token': '[PAD]'
        }
        self.base_tokenizer.add_special_tokens(special_tokens)

        # Snapshot the ids/strings of the core special tokens for convenient
        # access. They are read once here and are not refreshed later.
        self.pad_token_id = self.base_tokenizer.pad_token_id
        self.eos_token_id = self.base_tokenizer.eos_token_id
        self.bos_token_id = self.base_tokenizer.bos_token_id
        self.pad_token = self.base_tokenizer.pad_token
        self.eos_token = self.base_tokenizer.eos_token
        self.bos_token = self.base_tokenizer.bos_token

    def add_composer_tokens(self, composers):
        """Register one ``<|composer_NAME|>`` special token per composer.

        The new tokens are *appended* to the tokenizer's existing additional
        special tokens, so the MIDI tokens registered in ``__init__`` stay in
        the ``additional_special_tokens`` list. Passing only the new tokens
        would make the base tokenizer replace that list.

        Args:
            composers: Iterable of composer identifiers; each is formatted
                into ``<|composer_{c}|>``.

        Returns:
            Number of tokens newly added to the vocabulary (0 if every
            composer token was already registered).
        """
        existing = [str(tok) for tok in self.base_tokenizer.additional_special_tokens]
        composer_tokens = [f'<|composer_{c}|>' for c in composers]

        # De-duplicate while preserving order: existing tokens first.
        merged = list(dict.fromkeys(existing + composer_tokens))
        if len(merged) == len(existing):
            return 0

        return self.base_tokenizer.add_special_tokens({
            'additional_special_tokens': merged
        })

    def __call__(self, text, **kwargs):
        """Tokenize ``text`` by delegating to the base tokenizer."""
        return self.base_tokenizer(text, **kwargs)

    def decode(self, token_ids, **kwargs):
        """Decode while preserving special tokens.

        ``skip_special_tokens`` defaults to ``False`` so the MIDI structure
        tokens survive the round-trip. A caller may still override it
        explicitly without triggering a duplicate-keyword ``TypeError``.
        """
        kwargs.setdefault('skip_special_tokens', False)
        return self.base_tokenizer.decode(token_ids, **kwargs)

    def pad(self, *args, **kwargs):
        """Pad encoded inputs by delegating to the base tokenizer."""
        return self.base_tokenizer.pad(*args, **kwargs)

    def get_vocab(self):
        """Return the base tokenizer's vocabulary mapping."""
        return self.base_tokenizer.get_vocab()

    def __len__(self):
        """Return the full vocabulary size, including added special tokens.

        Useful for resizing a model's embedding matrix after tokens have been
        added (e.g. ``model.resize_token_embeddings(len(tokenizer))``).
        """
        return len(self.base_tokenizer)