# NOTE: removed web-page extraction residue (HF Spaces header, file size,
# commit hash, line-number gutter) that was not part of the source file.
from aria.data.midi import MidiDict
# from aria.tokenizer import AbsTokenizer
# aria_tokenizer = AbsTokenizer()
import yaml
import jsonlines
import glob
import random
import os
import sys
import pickle
import json
import argparse
import numpy as np
from copy import deepcopy
from torch.utils.data import Dataset
import torch
from torch.nn import functional as F
from transformers import T5Tokenizer
from spacy.lang.en import English
# Make this file's parent directory importable so sibling project modules
# resolve when the script is run directly (not as part of a package).
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))
class Text2MusicDataset(Dataset):
    """Dataset pairing natural-language captions with tokenized MIDI sequences.

    Each item yields FLAN-T5 token ids for a (randomly sentence-dropped)
    caption plus a fixed-length tensor of MIDI token ids for the decoder.
    """

    def __init__(self, configs, captions, aria_tokenizer, mode="train", shuffle=False):
        """
        Args:
            configs: parsed YAML config dict (paths, model settings).
            captions: list of dicts with at least 'caption' and 'location' keys
                (location is a MIDI path relative to the dataset folder).
            aria_tokenizer: Aria tokenizer instance (e.g. AbsTokenizer) used to
                turn a MidiDict into a token sequence.
            mode: "train"/"eval" tag; stored but not otherwise used here.
            shuffle: if True, items are visited in random order.
        """
        self.mode = mode
        # Copy before shuffling so the caller's list is not mutated in place
        # (the original shuffled the shared list as a side effect).
        self.captions = list(captions) if shuffle else captions
        if shuffle:
            random.shuffle(self.captions)

        # Root folder containing the MidiCaps MIDI files
        self.dataset_path = configs['raw_data']['raw_data_folders']['midicaps']['folder_path']
        # Artifact folder holding the pickled MIDI-token vocabulary
        self.artifact_folder = configs['artifact_folder']
        tokenizer_filepath = os.path.join(self.artifact_folder, "vocab.pkl")
        self.aria_tokenizer = aria_tokenizer
        # token string -> integer id mapping for the decoder vocabulary
        with open(tokenizer_filepath, 'rb') as f:
            self.tokenizer = pickle.load(f)

        # spaCy pipeline used only to split captions into sentences
        self.nlp = English()
        self.nlp.add_pipe('sentencizer')

        # FLAN-T5 tokenizer for the text encoder side
        self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        # Hard cap on decoder (MIDI) sequence length; shorter sequences are
        # zero-padded, longer ones keep only the last N tokens.
        self.decoder_max_sequence_length = configs['model']['text2midi_model']['decoder_max_sequence_length']

        print("Length of dataset: ", len(self.captions))

    def __len__(self):
        """Number of caption/MIDI pairs."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return (input_ids, attention_mask, labels) for item `idx`.

        input_ids/attention_mask: FLAN-T5 encoding of the (possibly
        sentence-dropped) caption. labels: int64 tensor of MIDI token ids,
        padded/truncated to `decoder_max_sequence_length`.
        """
        caption = self.captions[idx]['caption']
        midi_filepath = os.path.join(self.dataset_path, self.captions[idx]['location'])

        # Tokenize the MIDI file; empty files get just the start/end markers.
        midi = MidiDict.from_midi(midi_filepath)
        if len(midi.note_msgs) == 0:
            aria_tokenized_midi = ["<SS>", "<E>"]
        else:
            aria_tokenized_midi = ["<SS>"] + self.aria_tokenizer.tokenize(midi)

        # Augmentation: with p=0.5 drop 20-50% of the caption's sentences.
        do_drop = random.random() > 0.5
        if do_drop:
            sentences = list(self.nlp(caption).sents)
            sent_length = len(sentences)
            # floor for short captions (may drop nothing), ceil for longer ones
            fraction = (20 + random.random() * 30) / 100 * sent_length
            if sent_length < 4:
                how_many_to_drop = int(np.floor(fraction))
            else:
                how_many_to_drop = int(np.ceil(fraction))
            # Hoist to a set: O(1) membership instead of rebuilding a list
            # on every comprehension iteration.
            which_to_drop = set(
                np.random.choice(sent_length, how_many_to_drop, replace=False).tolist()
            )
            kept = [sentences[i] for i in range(sent_length) if i not in which_to_drop]
            new_sentences = " ".join(s.text for s in kept)
        else:
            new_sentences = caption

        # Encode the caption for the text encoder
        inputs = self.t5_tokenizer(new_sentences, return_tensors='pt', padding=True, truncation=True)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        # Map MIDI token strings to ids; unknown tokens are silently skipped
        # (original behavior — vocab is assumed to cover the tokenizer output).
        tokenized_midi = [self.tokenizer[token] for token in aria_tokenized_midi if token in self.tokenizer]

        # Right-pad with zeros, or keep the trailing window if too long.
        if len(tokenized_midi) < self.decoder_max_sequence_length:
            labels = F.pad(
                torch.tensor(tokenized_midi),
                (0, self.decoder_max_sequence_length - len(tokenized_midi)),
            ).to(torch.int64)
        else:
            labels = torch.tensor(tokenized_midi[-self.decoder_max_sequence_length:]).to(torch.int64)

        return input_ids, attention_mask, labels
if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=os.path.normpath("../configs/config.yaml"),
                        help="Path to the config file")
    args = parser.parse_args()

    # Load config file
    with open(args.config, 'r') as f:
        configs = yaml.safe_load(f)

    # Load the caption dataset (one JSON object per line)
    caption_dataset_path = configs['raw_data']['caption_dataset_path']
    with jsonlines.open(caption_dataset_path) as reader:
        captions = list(reader)

    # BUG FIX: the original call omitted the required `aria_tokenizer`
    # positional argument, so construction raised a TypeError. Build one here
    # (import kept local to the script entry point, mirroring the module's
    # commented-out top-level import).
    from aria.tokenizer import AbsTokenizer
    aria_tokenizer = AbsTokenizer()

    # Build the dataset and smoke-test a single item
    dataset = Text2MusicDataset(configs, captions, aria_tokenizer, mode="train", shuffle=True)
    input_ids, attention_mask, labels = dataset[0]
    print(labels.shape)