# TextToAudio / text2midi_repo / model / data_loader_remi.py
# Provenance: uploaded by kmaes ("Upload 27 files", commit b148e11, verified)
import yaml
import jsonlines
import glob
import random
import os
import sys
import pickle
import json
import argparse
import numpy as np
from copy import deepcopy
from torch.utils.data import Dataset
import torch
from torch.nn import functional as F
from transformers import T5Tokenizer
from spacy.lang.en import English
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))
class Text2MusicDataset(Dataset):
    """Torch dataset pairing MidiCaps text captions with REMI-tokenized MIDI.

    Each item is a tuple ``(input_ids, attention_mask, labels)``:
    FLAN-T5 token ids and attention mask for the (possibly randomly trimmed)
    caption, plus a fixed-length int64 tensor of REMI MIDI token ids
    (BOS + tokens + EOS, padded/truncated to ``decoder_max_sequence_length``).
    """
    def __init__(self, configs, captions, remi_tokenizer, mode="train", shuffle = False):
        """Build the dataset.

        Args:
            configs: parsed YAML config dict; reads
                ``raw_data.raw_data_folders.midicaps.folder_path``,
                ``artifact_folder`` and
                ``model.text2midi_model.decoder_max_sequence_length``.
            captions: list of dicts, each with at least ``'caption'`` (str)
                and ``'location'`` (MIDI path relative to the dataset folder).
            remi_tokenizer: MIDI tokenizer; assumed to be a miditok REMI
                tokenizer supporting ``tokenizer(path)`` and
                ``tokenizer["TOKEN_None"]`` lookups — TODO confirm.
            mode: dataset split tag; stored but not otherwise used here.
            shuffle: if True, shuffle ``captions`` in place once at init.
        """
        self.mode = mode
        self.captions = captions
        if shuffle:
            # In-place shuffle: mutates the caller's list as well.
            random.shuffle(self.captions)
        # Path to dataset (root folder containing the MidiCaps MIDI files)
        self.dataset_path = configs['raw_data']['raw_data_folders']['midicaps']['folder_path']
        # Artifact folder
        self.artifact_folder = configs['artifact_folder']
        # Load encoder tokenizer json file dictionary
        # tokenizer_filepath = os.path.join(self.artifact_folder, "vocab.pkl")
        # Load the pickled tokenizer dictionary
        # with open(tokenizer_filepath, 'rb') as f:
        #     self.tokenizer = pickle.load(f)
        self.remi_tokenizer = remi_tokenizer
        # Load the sentencizer (rule-based sentence splitter; no full pipeline
        # needed, only sentence boundaries for caption-dropping augmentation)
        self.nlp = English()
        self.nlp.add_pipe('sentencizer')
        # Load the FLAN-T5 tokenizer and encoder
        # NOTE(review): downloads/caches "google/flan-t5-base" on first use.
        self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        # Get the maximum sequence length for the MIDI (decoder) side
        self.decoder_max_sequence_length = configs['model']['text2midi_model']['decoder_max_sequence_length']
        # Print length of dataset
        print("Length of dataset: ", len(self.captions))

    def __len__(self) -> int:
        """Number of (caption, MIDI) pairs."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, labels)`` for item ``idx``.

        ``input_ids``/``attention_mask`` are T5 tensors of shape (1, seq);
        ``labels`` is an int64 tensor of fixed length
        ``decoder_max_sequence_length``, zero-padded on the right.
        Non-deterministic: sentence-dropping augmentation uses ``random``.
        """
        caption = self.captions[idx]['caption']
        midi_filepath = os.path.join(self.dataset_path, self.captions[idx]['location'])
        # print(f'midi filepath: {midi_filepath}')
        # Read the MIDI file and wrap the token ids in BOS/EOS markers.
        tokens = self.remi_tokenizer(midi_filepath)
        if len(tokens.ids) == 0:
            # Empty tokenization: fall back to a BOS/EOS-only sequence.
            tokenized_midi = [self.remi_tokenizer["BOS_None"], self.remi_tokenizer["EOS_None"]]
        else:
            tokenized_midi = [self.remi_tokenizer["BOS_None"]] + tokens.ids + [self.remi_tokenizer["EOS_None"]]
        # Drop a random number of sentences from the caption (augmentation,
        # applied with probability 0.5)
        do_drop = random.random() > 0.5
        if do_drop:
            sentences = list(self.nlp(caption).sents)
            sent_length = len(sentences)
            if sent_length<4:
                # floor: for short captions this often drops 0 sentences
                how_many_to_drop = int(np.floor((20 + random.random()*30)/100*sent_length)) # between 20 and 50 percent of sentences
            else:
                # ceil: for longer captions always drop at least one sentence
                how_many_to_drop = int(np.ceil((20 + random.random()*30)/100*sent_length)) # between 20 and 50 percent of sentences
            # Choose distinct sentence indices to remove.
            which_to_drop = np.random.choice(sent_length, how_many_to_drop, replace=False)
            new_sentences = [sentences[i] for i in range(sent_length) if i not in which_to_drop.tolist()]
            new_sentences = " ".join([new_sentences[i].text for i in range(len(new_sentences))]) # combine sentences back with a space
        else:
            new_sentences = caption
        # Tokenize the caption
        inputs = self.t5_tokenizer(new_sentences, return_tensors='pt', padding=True, truncation=True)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        # Convert the tokenized MIDI file to a tensor and pad it to the maximum
        # sequence length (pad value 0 — presumably the PAD token id; verify
        # against the tokenizer's vocab)
        if len(tokenized_midi) < self.decoder_max_sequence_length:
            labels = F.pad(torch.tensor(tokenized_midi), (0, self.decoder_max_sequence_length - len(tokenized_midi))).to(torch.int64)
        else:
            # Truncate; note EOS is cut off when the sequence is too long.
            labels = torch.tensor(tokenized_midi[0:self.decoder_max_sequence_length]).to(torch.int64)
        return input_ids, attention_mask, labels
if __name__ == "__main__":
    # Smoke test: build the dataset from config files, fetch one item and
    # round-trip the MIDI labels back to a .mid file on disk.

    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=os.path.normpath("../configs/config.yaml"),
                        help="Path to the config file")
    args = parser.parse_args()

    # Load the pickled REMI tokenizer dictionary
    tokenizer_filepath = "../artifacts/vocab_remi.pkl"
    with open(tokenizer_filepath, "rb") as f:
        tokenizer = pickle.load(f)

    # Fix: the original looked up "PAD_None" but named/printed it as the
    # BOS token number — it is the PAD token id.
    pad_token_number = tokenizer["PAD_None"]
    print(f"pad_token_number: {pad_token_number}")

    # Load config file
    with open(args.config, 'r') as f:
        configs = yaml.safe_load(f)

    # Load the caption dataset (one JSON object per line)
    caption_dataset_path = configs['raw_data']['caption_dataset_path']
    with jsonlines.open(caption_dataset_path) as reader:
        captions = list(reader)

    # Build the dataset and decode the first item's MIDI labels
    dataset = Text2MusicDataset(configs, captions, remi_tokenizer=tokenizer, mode="train", shuffle=True)
    input_ids, attention_mask, labels = dataset[0]
    print(type(input_ids))
    generated_midi = tokenizer.decode(labels)
    print(type(generated_midi))
    generated_midi.dump_midi("decoded_midi.mid")