Commit
·
ae3131d
1
Parent(s):
2171a21
Add generate only continuation
Browse files
- gen_res/0.json +0 -1
- gen_res/0.mid +0 -0
- generate_on_one_track.py +105 -0
gen_res/0.json
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"ids": [[155, 32, 112, 122, 158, 37, 112, 121, 161, 41, 597, 163, 25, 469, 166, 36, 113, 122, 184, 25, 113, 119, 141, 25, 441, 145, 25, 113, 120, 147, 25, 111, 126, 29, 111, 130, 149, 27, 597, 154, 30, 221, 157, 25, 111, 124, 161, 25, 112, 121, 166, 29, 822, 308, 32, 112, 131, 149, 30, 112, 125, 159, 32, 180, 160, 29, 441, 166, 29, 322, 190, 32, 111, 123, 143, 25, 110, 130, 152, 37, 221, 156, 25, 113, 122, 158, 25, 113, 119, 161, 13, 322, 166, 30, 112, 121, 168, 25, 112, 122, 184, 37, 111, 130, 37, 112, 129, 25, 822, 143, 25, 192, 144, 25, 180, 148, 25, 378, 150, 25, 378, 153, 25, 192, 156, 25, 112, 123, 25, 279, 161, 25, 378, 165, 25, 112, 125, 184, 25, 112, 122, 25, 112, 124, 25, 113, 119, 190, 25, 226, 139, 25, 180, 140, 25, 180, 144, 25, 226, 25, 192, 151, 25, 192, 25, 220, 152, 25, 220, 13, 180, 157, 25, 220, 32, 378, 13, 178, 158, 25, 220, 163, 25, 112, 122, 25, 220, 25, 196, 164, 25, 185, 167, 25, 220, 25, 114, 117, 25, 180, 184, 25, 180, 37, 378, 141, 25, 378, 143, 25, 220, 25, 180, 149, 25, 192, 25, 220, 149, 25, 220, 49, 220, 151, 25, 220, 37, 180, 154, 25, 192, 25, 192, 155, 25, 192, 158, 25, 113, 119, 25, 378, 13, 175, 161, 25, 192, 25, 178, 13, 171, 162, 25, 185, 167, 25, 378, 25, 220, 25, 220, 206, 25, 220, 25, 178, 141, 25, 178, 25, 178, 144, 25, 180, 25, 178, 146, 25, 175, 25, 191, 148, 25, 177, 25, 175, 149, 25, 175, 25, 178, 151, 25, 178, 152, 25, 191, 25, 192, 154, 25, 180, 25, 175, 25, 175, 155, 25, 178, 157, 25, 192, 25, 180, 25, 180, 25, 180, 160, 25, 180, 25, 192, 25, 180, 161, 25, 178, 25, 180, 163, 25, 378, 25, 378, 25, 175, 25, 185, 168, 25, 220, 25, 220, 25, 180, 32, 191, 144, 25, 220, 25, 192, 25, 181, 25, 181, 146, 13, 173, 24, 175, 149, 25, 185, 27, 181, 19, 175, 151, 24, 180, 25, 175, 32, 178, 17, 174, 155, 24, 180, 20, 178, 25, 178, 27, 178, 27, 180, 154, 31, 192, 27, 180, 27, 180, 20, 178, 24, 175, 144, 24, 192, 20, 180, 27, 180, 27, 178, 145, 31, 180, 19, 178, 20, 178, 31, 178, 148, 31, 378, 19, 175, 24, 174, 32, 192, 151, 31, 
180, 19, 180, 31, 192, 31, 180, 154, 27, 178, 32, 378, 23, 178, 25, 180, 155, 27, 192, 26, 180, 19, 175, 19, 220, 156, 19, 175, 19, 178, 157, 31, 378, 43, 192, 31, 192, 19, 180, 159, 25, 192, 31, 192, 19, 178, 31, 180, 160, 19, 178, 43, 192, 27, 178, 161, 26, 180, 26, 180, 162, 19, 180, 31, 180, 19, 192, 31, 180, 163, 31, 192, 19, 180, 19, 175, 31, 178, 19, 175, 31, 220, 165], [2, 184, 169, 141, 20, 113, 118, 22, 113, 118, 29, 113, 118, 38, 113, 118, 143, 17, 441, 24, 919, 33, 919, 40, 919, 147, 21, 113, 118, 28, 114, 118, 30, 113, 118, 39, 113, 118, 149, 16, 112, 128, 18, 111, 128, 25, 111, 128, 34, 111, 128, 165, 36, 219, 184, 13, 187, 139, 20, 179, 141, 27, 188, 143, 34, 196, 153, 20, 109, 132, 28, 211, 33, 211, 156, 30, 205, 35, 205, 158, 32, 205, 37, 211, 161, 27, 222, 34, 227, 165, 25, 222, 29, 222, 36, 235, 184, 26, 111, 128, 31, 111, 128, 153, 33, 407], [2, 184, 169, 141, 20, 113, 118, 22, 113, 118, 29, 113, 118, 38, 113, 118, 143, 17, 441, 24, 919, 33, 919, 40, 919, 147, 21, 113, 118, 28, 114, 118, 30, 113, 118, 39, 113, 118, 149, 16, 112, 128, 18, 111, 128, 25, 111, 128, 34, 111, 128, 165, 36, 219, 184, 13, 187, 139, 20, 179, 141, 27, 188, 143, 34, 196, 153, 20, 109, 132, 28, 211, 33, 211, 156, 30, 205, 35, 205, 158, 32, 205, 37, 211, 161, 27, 222, 34, 227, 165, 25, 222, 29, 222, 36, 235, 184, 26, 111, 128, 31, 111, 128, 153, 33, 407, 155, 32, 112, 122, 158, 37, 112, 121, 161, 41, 597, 163, 25, 469, 166, 36, 113, 122, 184, 25, 113, 119, 141, 25, 441, 145, 25, 113, 120, 147, 25, 111, 126, 29, 111, 130, 149, 27, 597, 154, 30, 221, 157, 25, 111, 124, 161, 25, 112, 121, 166, 29, 822, 308, 32, 112, 131, 149, 30, 112, 125, 159, 32, 180, 160, 29, 441, 166, 29, 322, 190, 32, 111, 123, 143, 25, 110, 130, 152, 37, 221, 156, 25, 113, 122, 158, 25, 113, 119, 161, 13, 322, 166, 30, 112, 121, 168, 25, 112, 122, 184, 37, 111, 130, 37, 112, 129, 25, 822, 143, 25, 192, 144, 25, 180, 148, 25, 378, 150, 25, 378, 153, 25, 192, 156, 25, 112, 123, 25, 279, 161, 25, 378, 165, 
25, 112, 125, 184, 25, 112, 122, 25, 112, 124, 25, 113, 119, 190, 25, 226, 139, 25, 180, 140, 25, 180, 144, 25, 226, 25, 192, 151, 25, 192, 25, 220, 152, 25, 220, 13, 180, 157, 25, 220, 32, 378, 13, 178, 158, 25, 220, 163, 25, 112, 122, 25, 220, 25, 196, 164, 25, 185, 167, 25, 220, 25, 114, 117, 25, 180, 184, 25, 180, 37, 378, 141, 25, 378, 143, 25, 220, 25, 180, 149, 25, 192, 25, 220, 149, 25, 220, 49, 220, 151, 25, 220, 37, 180, 154, 25, 192, 25, 192, 155, 25, 192, 158, 25, 113, 119, 25, 378, 13, 175, 161, 25, 192, 25, 178, 13, 171, 162, 25, 185, 167, 25, 378, 25, 220, 25, 220, 206, 25, 220, 25, 178, 141, 25, 178, 25, 178, 144, 25, 180, 25, 178, 146, 25, 175, 25, 191, 148, 25, 177, 25, 175, 149, 25, 175, 25, 178, 151, 25, 178, 152, 25, 191, 25, 192, 154, 25, 180, 25, 175, 25, 175, 155, 25, 178, 157, 25, 192, 25, 180, 25, 180, 25, 180, 160, 25, 180, 25, 192, 25, 180, 161, 25, 178, 25, 180, 163, 25, 378, 25, 378, 25, 175, 25, 185, 168, 25, 220, 25, 220, 25, 180, 32, 191, 144, 25, 220, 25, 192, 25, 181, 25, 181, 146, 13, 173, 24, 175, 149, 25, 185, 27, 181, 19, 175, 151, 24, 180, 25, 175, 32, 178, 17, 174, 155, 24, 180, 20, 178, 25, 178, 27, 178, 27, 180, 154, 31, 192, 27, 180, 27, 180, 20, 178, 24, 175, 144, 24, 192, 20, 180, 27, 180, 27, 178, 145, 31, 180, 19, 178, 20, 178, 31, 178, 148, 31, 378, 19, 175, 24, 174, 32, 192, 151, 31, 180, 19, 180, 31, 192, 31, 180, 154, 27, 178, 32, 378, 23, 178, 25, 180, 155, 27, 192, 26, 180, 19, 175, 19, 220, 156, 19, 175, 19, 178, 157, 31, 378, 43, 192, 31, 192, 19, 180, 159, 25, 192, 31, 192, 19, 178, 31, 180, 160, 19, 178, 43, 192, 27, 178, 161, 26, 180, 26, 180, 162, 19, 180, 31, 180, 19, 192, 31, 180, 163, 31, 192, 19, 180, 19, 175, 31, 178, 19, 175, 31, 220, 165]]}
|
|
|
|
|
|
gen_res/0.mid
DELETED
|
Binary file (3.32 kB)
|
|
|
generate_on_one_track.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from random import shuffle
|
| 4 |
+
|
| 5 |
+
from torch import Tensor, argmax
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from torch.cuda import is_available as cuda_available, is_bf16_supported
|
| 8 |
+
from torch.backends.mps import is_available as mps_available
|
| 9 |
+
from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoTokenizer, MistralForCausalLM
|
| 10 |
+
from transformers.trainer_utils import set_seed
|
| 11 |
+
from evaluate import load as load_metric
|
| 12 |
+
from miditok import REMI, TokenizerConfig
|
| 13 |
+
from miditok.pytorch_data import DatasetTok, DataCollator
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
# Our tokenizer's configuration.
# REMI vocabulary settings — presumably these mirror the settings the
# pretrained tokenizer/model were trained with (verify against training script).
PITCH_RANGE = (21, 109)  # (min_pitch, max_pitch) MIDI note numbers
BEAT_RES = {(0, 1): 8, (1, 2): 4, (2, 4): 2, (4, 8): 1}  # positions per beat, by beat range
NUM_VELOCITIES = 24  # number of velocity bins
SPECIAL_TOKENS = ["PAD", "MASK", "BOS", "EOS"]
USE_CHORDS = False
USE_RESTS = False
USE_TEMPOS = True
USE_TIME_SIGNATURE = False
USE_PROGRAMS = False
NUM_TEMPOS = 32  # number of tempo bins
TEMPO_RANGE = (50, 200)  # (min_tempo, max_tempo)
TOKENIZER_PARAMS = {
    "pitch_range": PITCH_RANGE,
    "beat_res": BEAT_RES,
    "num_velocities": NUM_VELOCITIES,
    "special_tokens": SPECIAL_TOKENS,
    "use_chords": USE_CHORDS,
    "use_rests": USE_RESTS,
    "use_tempos": USE_TEMPOS,
    "use_time_signatures": USE_TIME_SIGNATURE,
    "use_programs": USE_PROGRAMS,
    "num_tempos": NUM_TEMPOS,
    "tempo_range": TEMPO_RANGE,
}
# NOTE(review): `config` is never used below — the tokenizer is loaded with
# REMI.from_pretrained() instead. Kept here for reference only.
config = TokenizerConfig(**TOKENIZER_PARAMS)
|
| 42 |
+
|
| 43 |
+
# Seed for reproducible sampling.
set_seed(777)

# Creates the tokenizer: load the pretrained REMI tokenizer from the Hub.
tokenizer = REMI.from_pretrained("sunsetsobserver/MIDI")

# Collect every MIDI file under input/ (both .mid and .midi extensions).
midi_paths = list(Path('input').glob('**/*.mid')) + list(Path('input').glob('**/*.midi'))

# Loads tokens and creates the data collator.
kwargs_dataset = {"min_seq_len": 10, "max_seq_len": 1024, "tokenizer": tokenizer}
dataset_test = DatasetTok(midi_paths, **kwargs_dataset)

# Creates the model from the locally saved fine-tuned weights.
model = MistralForCausalLM.from_pretrained("./runs")

# Single collator construction. The original built one collator and then
# immediately rebuilt it with copy_inputs_as_labels=True, making the first
# construction dead code — only this one is needed.
collator = DataCollator(
    tokenizer["PAD_None"],
    tokenizer["BOS_None"],
    tokenizer["EOS_None"],
    copy_inputs_as_labels=True,
)
|
| 64 |
+
|
| 65 |
+
# Output directory for the generated MIDI / token files.
gen_results_path = Path('gen_res')
gen_results_path.mkdir(parents=True, exist_ok=True)

# Sampling configuration: extend each prompt by up to 512 tokens using
# plain multinomial sampling (no beam search), filtered by temperature,
# top-k, top-p and the epsilon/eta cutoffs.
generation_config = GenerationConfig(
    max_new_tokens=512,  # extends samples by 512 tokens
    do_sample=True,      # sample instead of greedy / beam search
    num_beams=1,         # no beam search
    temperature=0.9,
    top_k=15,
    top_p=0.95,
    epsilon_cutoff=3e-4,
    eta_cutoff=1e-3,
)
|
| 76 |
+
|
| 77 |
+
# Here the sequences are padded to the left, so that the last token along the
# time dimension is always the last token of each seq, allowing to efficiently
# generate by batch.
collator.pad_on_left = True
collator.eos_token = None
dataloader_test = DataLoader(dataset_test, batch_size=1, collate_fn=collator)
model.eval()

count = 0
for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):  # (N,T)
    res = model.generate(
        inputs=batch["input_ids"].to(model.device),
        attention_mask=batch["attention_mask"].to(model.device),
        generation_config=generation_config)  # (N,T)

    # Saves the generated music, as MIDI files and tokens (json).
    # `res` rows already contain prompt + continuation, so the prompt tensor
    # is not needed separately (the original zipped over an unused `prompt`).
    for continuation in res:
        # Generate the MIDI for the entire sequence (prompt + continuation).
        midi = tokenizer.tokens_to_midi([deepcopy(continuation.tolist())])

        # Set the track name to indicate it includes both the original and the continuation.
        midi.tracks[0].name = f'Original sample and continuation ({len(continuation)} tokens)'

        # Dump the MIDI file for the combined prompt and continuation.
        midi.dump_midi(gen_results_path / f'{count}.mid')

        # Save the tokens for the combined sequence.
        tokens = [continuation.tolist()]  # This time, only saving the combined sequence
        tokenizer.save_tokens(tokens, gen_results_path / f'{count}.json')

        # Increment once per generated sequence (not per batch) so output
        # files are not overwritten if batch_size is ever raised above 1.
        count += 1
|