# -*- coding: utf-8 -*-
"""miditokenizer.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/16YJUBYcqKYPVIhwzKNi4ELnTftr2TcUY
"""
# We wrap a base GPT-2 tokenizer with helper methods for MIDI event and
# composer tokens. Because the model has a fixed sequence limit, datasets
# are built by splitting each file into chunks, with position information
# attached to every chunk as an extra training signal (a hedged sketch of
# this chunking follows the class definition below).
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from torch.utils.data import Dataset
from pathlib import Path
import torch
class MIDITokenizer:
    """Wraps a GPT-2 tokenizer with special tokens for MIDI data."""

    def __init__(self, pretrained_model='gpt2'):
        self.base_tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_model)
        special_tokens = {
            'additional_special_tokens': [
                '<|START_METADATA|>',
                '<|END_METADATA|>',
                '<|START_TRACK|>',
                '<|END_TRACK|>',
                '<metadata>',
                '<tempo>',
                '<time_signature>',
                '<program_change>',
                '<note_on>',
                '<note_off>',
                '<control_change>'
            ],
            'pad_token': '[PAD]'
        }
        self.base_tokenizer.add_special_tokens(special_tokens)
        # Mirror the common tokenizer attributes so this class can stand in
        # for the base tokenizer in training code.
        self.pad_token_id = self.base_tokenizer.pad_token_id
        self.eos_token_id = self.base_tokenizer.eos_token_id
        self.bos_token_id = self.base_tokenizer.bos_token_id
        self.pad_token = self.base_tokenizer.pad_token
        self.eos_token = self.base_tokenizer.eos_token
        self.bos_token = self.base_tokenizer.bos_token
    def add_composer_tokens(self, composers):
        # Register one special token per composer, e.g. '<|composer_bach|>'.
        composer_tokens = [f'<|composer_{c}|>' for c in composers]
        self.base_tokenizer.add_special_tokens({
            'additional_special_tokens': composer_tokens
        })
    def __call__(self, text, **kwargs):
        return self.base_tokenizer(text, **kwargs)

    def decode(self, token_ids, **kwargs):
        """Decode while preserving special tokens."""
        return self.base_tokenizer.decode(token_ids, skip_special_tokens=False, **kwargs)

    def pad(self, *args, **kwargs):
        return self.base_tokenizer.pad(*args, **kwargs)

    def get_vocab(self):
        return self.base_tokenizer.get_vocab()
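
# Hedged usage sketch (not part of the generated notebook): shows how the
# tokenizer above might be wired to a model. The composer names and the
# event string are made-up examples; only the public transformers API
# (add_special_tokens / resize_token_embeddings) is assumed.
if __name__ == '__main__':
    tokenizer = MIDITokenizer()
    tokenizer.add_composer_tokens(['bach', 'chopin'])

    # After adding special tokens the embedding matrix must be resized,
    # otherwise the new token ids fall outside the model's vocabulary.
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    model.resize_token_embeddings(len(tokenizer.get_vocab()))

    encoded = tokenizer('<|composer_bach|> <note_on> 60 <note_off> 60',
                        return_tensors='pt')
    print(tokenizer.decode(encoded['input_ids'][0]))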
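
# Minimal sketch of the chunked dataset described in the header comments.
# The class name, the default chunk length, and the way position information
# is encoded (a per-window chunk index) are all assumptions; the notebook's
# actual Dataset may differ.
class ChunkedMIDIDataset(Dataset):
    """Splits long token sequences into fixed-length chunks for training."""

    def __init__(self, texts, tokenizer, max_length=1024):
        self.examples = []
        for text in texts:
            ids = tokenizer(text)['input_ids']
            # Slice the full sequence into model-sized windows; the chunk
            # index doubles as coarse position information for each window.
            for start in range(0, len(ids), max_length):
                chunk = ids[start:start + max_length]
                self.examples.append({
                    'input_ids': torch.tensor(chunk, dtype=torch.long),
                    'chunk_index': start // max_length,
                })

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]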