updated generation process - epsilon
Browse files- familizer.py +0 -1
- generate.py +62 -29
- generation_utils.py +35 -6
familizer.py
CHANGED
|
@@ -115,7 +115,6 @@ class Familizer:
|
|
| 115 |
|
| 116 |
|
| 117 |
if __name__ == "__main__":
|
| 118 |
-
|
| 119 |
# Choose number of jobs for parallel processing
|
| 120 |
n_jobs = -1
|
| 121 |
|
|
|
|
| 115 |
|
| 116 |
|
| 117 |
if __name__ == "__main__":
|
|
|
|
| 118 |
# Choose number of jobs for parallel processing
|
| 119 |
n_jobs = -1
|
| 120 |
|
generate.py
CHANGED
|
@@ -1,8 +1,5 @@
|
|
| 1 |
from generation_utils import *
|
| 2 |
-
|
| 3 |
-
from load import LoadModel
|
| 4 |
-
from decoder import TextDecoder
|
| 5 |
-
from playback import get_music
|
| 6 |
|
| 7 |
|
| 8 |
class GenerateMidiText:
|
|
@@ -100,15 +97,26 @@ class GenerateMidiText:
|
|
| 100 |
text = text.rstrip(" ").rstrip("TRACK_END")
|
| 101 |
return text
|
| 102 |
|
| 103 |
-
def get_last_generated_track(self,
|
| 104 |
-
track
|
| 105 |
-
|
| 106 |
-
+ self.striping_track_ends(full_piece.split("TRACK_START ")[-1])
|
| 107 |
-
+ "TRACK_END "
|
| 108 |
-
) # forcing the space after track and
|
| 109 |
return track
|
| 110 |
|
| 111 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
text = ""
|
| 113 |
for bar in self.piece_by_track[track_id]["bars"]:
|
| 114 |
text += bar
|
|
@@ -122,18 +130,12 @@ class GenerateMidiText:
|
|
| 122 |
def get_whole_piece_from_bar_dict(self):
|
| 123 |
text = "PIECE_START "
|
| 124 |
for track_id, _ in enumerate(self.piece_by_track):
|
| 125 |
-
text += self.
|
| 126 |
return text
|
| 127 |
|
| 128 |
-
def delete_one_track(self, track):
|
| 129 |
self.piece_by_track.pop(track)
|
| 130 |
|
| 131 |
-
# def update_piece_dict__add_track(self, track_id, track):
|
| 132 |
-
# self.piece_dict[track_id] = track
|
| 133 |
-
|
| 134 |
-
# def update_all_dictionnaries__add_track(self, track):
|
| 135 |
-
# self.update_piece_dict__add_track(track_id, track)
|
| 136 |
-
|
| 137 |
"""Basic generation tools"""
|
| 138 |
|
| 139 |
def tokenize_input_prompt(self, input_prompt, verbose=True):
|
|
@@ -238,10 +240,12 @@ class GenerateMidiText:
|
|
| 238 |
)
|
| 239 |
else:
|
| 240 |
print('"--- Wrong length - Regenerating ---')
|
|
|
|
| 241 |
if not bar_count_checks:
|
| 242 |
failed += 1
|
| 243 |
-
|
| 244 |
-
|
|
|
|
| 245 |
|
| 246 |
return full_piece
|
| 247 |
|
|
@@ -298,8 +302,7 @@ class GenerateMidiText:
|
|
| 298 |
|
| 299 |
""" Piece generation - Extra Bars """
|
| 300 |
|
| 301 |
-
|
| 302 |
-
def process_prompt_for_next_bar(self, track_idx):
|
| 303 |
"""Processing the prompt for the model to generate one more bar only.
|
| 304 |
The prompt contains:
|
| 305 |
if not the first bar: the previous, already processed, bars of the track
|
|
@@ -318,6 +321,10 @@ class GenerateMidiText:
|
|
| 318 |
if i != track_idx:
|
| 319 |
len_diff = len(othertrack["bars"]) - len(track["bars"])
|
| 320 |
if len_diff > 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
# if other bars are longer, it means that this one should catch up
|
| 322 |
pre_promt += othertrack["bars"][0]
|
| 323 |
for bar in track["bars"][-self.model_n_bar :]:
|
|
@@ -325,7 +332,7 @@ class GenerateMidiText:
|
|
| 325 |
pre_promt += "TRACK_END "
|
| 326 |
elif (
|
| 327 |
False
|
| 328 |
-
): # len_diff <= 0: # THIS DOES NOT WORK - It just
|
| 329 |
# adding an empty bars at the end of the other tracks if they have not been processed yet
|
| 330 |
pre_promt += othertracks["bars"][0]
|
| 331 |
for bar in track["bars"][-(self.model_n_bar - 1) :]:
|
|
@@ -337,27 +344,54 @@ class GenerateMidiText:
|
|
| 337 |
# for the bar to prolong
|
| 338 |
# initialization e.g TRACK_START INST=DRUMS DENSITY=2
|
| 339 |
processed_prompt = track["bars"][0]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
for bar in track["bars"][-(self.model_n_bar - 1) :]:
|
| 341 |
# adding the "last" bars of the track
|
| 342 |
processed_prompt += bar
|
| 343 |
|
| 344 |
processed_prompt += "BAR_START "
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
print(
|
| 346 |
f"--- prompt length = {len((pre_promt + processed_prompt).split(' '))} ---"
|
| 347 |
)
|
|
|
|
| 348 |
return pre_promt + processed_prompt
|
| 349 |
|
| 350 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
"""Generate one more bar from the input_prompt"""
|
| 352 |
-
processed_prompt = self.process_prompt_for_next_bar(
|
|
|
|
| 353 |
prompt_plus_bar = self.generate_until_track_end(
|
| 354 |
input_prompt=processed_prompt,
|
| 355 |
-
temperature=self.piece_by_track[
|
| 356 |
expected_length=1,
|
| 357 |
verbose=False,
|
| 358 |
)
|
| 359 |
added_bar = self.get_newly_generated_bar(prompt_plus_bar)
|
| 360 |
-
self.update_track_dict__add_bars(added_bar,
|
| 361 |
|
| 362 |
def get_newly_generated_bar(self, prompt_plus_bar):
|
| 363 |
return "BAR_START " + self.striping_track_ends(
|
|
@@ -380,7 +414,6 @@ class GenerateMidiText:
|
|
| 380 |
self.check_the_piece_for_errors()
|
| 381 |
|
| 382 |
def check_the_piece_for_errors(self, piece: str = None):
|
| 383 |
-
|
| 384 |
if piece is None:
|
| 385 |
piece = self.get_whole_piece_from_bar_dict()
|
| 386 |
errors = []
|
|
|
|
| 1 |
from generation_utils import *
|
| 2 |
+
import random
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
class GenerateMidiText:
|
|
|
|
| 97 |
text = text.rstrip(" ").rstrip("TRACK_END")
|
| 98 |
return text
|
| 99 |
|
| 100 |
+
def get_last_generated_track(self, piece):
    """Return the final track of a piece written as a single long string.

    The piece is split into tracks by ``get_tracks_from_a_piece`` and the
    last element of that list is handed back.
    """
    every_track = self.get_tracks_from_a_piece(piece)
    return every_track[-1]
|
| 104 |
|
| 105 |
+
def get_tracks_from_a_piece(self, piece):
    """Get all the tracks from a piece written as a single long string.

    The piece is split on "TRACK_START "; whatever precedes the first marker
    is discarded. Every remaining chunk is cleaned with
    ``striping_track_ends`` and wrapped again as
    "TRACK_START " + chunk + "TRACK_END ".

    Args:
        piece (str): the piece text

    Returns:
        list[str]: one string per track, in order of appearance; empty when
        the piece contains no "TRACK_START " marker.
    """
    track_bodies = piece.split("TRACK_START ")[1:]
    # striping_track_ends applies str.rstrip to its argument, so it has to be
    # called once per track chunk rather than on the whole list of chunks.
    return [
        "TRACK_START " + self.striping_track_ends(track_body) + "TRACK_END "
        for track_body in track_bodies
    ]
|
| 112 |
+
|
| 113 |
+
def get_piece_from_track_list(self, track_list):
    """Build a piece string: "PIECE_START " followed by each track in order."""
    return "PIECE_START " + "".join(track_list)
|
| 118 |
+
|
| 119 |
+
def get_whole_track_from_bar_dict(self, track_id):
|
| 120 |
text = ""
|
| 121 |
for bar in self.piece_by_track[track_id]["bars"]:
|
| 122 |
text += bar
|
|
|
|
| 130 |
def get_whole_piece_from_bar_dict(self):
    """Return the whole piece as text.

    Starts with "PIECE_START " and appends, for each position in
    ``self.piece_by_track``, the text returned by
    ``get_whole_track_from_bar_dict`` for that position.
    """
    track_texts = (
        self.get_whole_track_from_bar_dict(position)
        for position, _ in enumerate(self.piece_by_track)
    )
    return "PIECE_START " + "".join(track_texts)
|
| 135 |
|
| 136 |
+
def delete_one_track(self, track):
    """Remove the entry ``track`` from ``self.piece_by_track`` using its ``pop`` method."""
    stored_tracks = self.piece_by_track
    stored_tracks.pop(track)
|
| 138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
"""Basic generation tools"""
|
| 140 |
|
| 141 |
def tokenize_input_prompt(self, input_prompt, verbose=True):
|
|
|
|
| 240 |
)
|
| 241 |
else:
|
| 242 |
print('"--- Wrong length - Regenerating ---')
|
| 243 |
+
|
| 244 |
if not bar_count_checks:
|
| 245 |
failed += 1
|
| 246 |
+
|
| 247 |
+
if failed > 2:
|
| 248 |
+
bar_count_checks = True # exit the while loop if failed too much
|
| 249 |
|
| 250 |
return full_piece
|
| 251 |
|
|
|
|
| 302 |
|
| 303 |
""" Piece generation - Extra Bars """
|
| 304 |
|
| 305 |
+
def process_prompt_for_next_bar(self, track_idx, verbose=True):
|
|
|
|
| 306 |
"""Processing the prompt for the model to generate one more bar only.
|
| 307 |
The prompt contains:
|
| 308 |
if not the first bar: the previous, already processed, bars of the track
|
|
|
|
| 321 |
if i != track_idx:
|
| 322 |
len_diff = len(othertrack["bars"]) - len(track["bars"])
|
| 323 |
if len_diff > 0:
|
| 324 |
+
if verbose:
|
| 325 |
+
print(
|
| 326 |
+
f"Adding bars - {len(track['bars'][-self.model_n_bar :])} selected from SIDE track: {i} for prompt"
|
| 327 |
+
)
|
| 328 |
# if other bars are longer, it means that this one should catch up
|
| 329 |
pre_promt += othertrack["bars"][0]
|
| 330 |
for bar in track["bars"][-self.model_n_bar :]:
|
|
|
|
| 332 |
pre_promt += "TRACK_END "
|
| 333 |
elif (
|
| 334 |
False
|
| 335 |
+
): # len_diff <= 0: # THIS DOES NOT WORK - It just adds empty bars
|
| 336 |
# adding an empty bars at the end of the other tracks if they have not been processed yet
|
| 337 |
pre_promt += othertracks["bars"][0]
|
| 338 |
for bar in track["bars"][-(self.model_n_bar - 1) :]:
|
|
|
|
| 344 |
# for the bar to prolong
|
| 345 |
# initialization e.g TRACK_START INST=DRUMS DENSITY=2
|
| 346 |
processed_prompt = track["bars"][0]
|
| 347 |
+
if verbose:
|
| 348 |
+
print(
|
| 349 |
+
f"Adding bars - {len(track['bars'][-(self.model_n_bar - 1) :])} selected from MAIN track: {track_idx} for prompt"
|
| 350 |
+
)
|
| 351 |
for bar in track["bars"][-(self.model_n_bar - 1) :]:
|
| 352 |
# adding the "last" bars of the track
|
| 353 |
processed_prompt += bar
|
| 354 |
|
| 355 |
processed_prompt += "BAR_START "
|
| 356 |
+
|
| 357 |
+
# making the pre-prompt short enough to avoid bugs due to the length of the prompt (model limitation)
|
| 358 |
+
pre_promt = self.force_prompt_length(pre_promt, 1500)
|
| 359 |
+
|
| 360 |
print(
|
| 361 |
f"--- prompt length = {len((pre_promt + processed_prompt).split(' '))} ---"
|
| 362 |
)
|
| 363 |
+
|
| 364 |
return pre_promt + processed_prompt
|
| 365 |
|
| 366 |
+
def force_prompt_length(self, prompt, expected_length):
    """Drop whole tracks from the prompt until it is short enough.

    Length is measured as the number of space-separated tokens, i.e.
    ``len(prompt.split(" "))``. A prompt that is already below
    ``expected_length`` is returned untouched. Otherwise randomly chosen
    tracks are removed one at a time, keeping the remaining tracks in their
    original order, until the rebuilt prompt is below the limit or no track
    is left to remove.

    Args:
        prompt (str): the prompt to be processed
        expected_length (int): the token count the prompt must stay below

    Returns:
        str: the original prompt, or a prompt rebuilt with
        ``get_piece_from_track_list`` from the tracks that were kept.
    """
    if len(prompt.split(" ")) < expected_length:
        return prompt

    tracks = list(self.get_tracks_from_a_piece(prompt))
    truncated_prompt = prompt
    # When the over-long prompt holds no track there is nothing to remove:
    # the loop is skipped and the prompt comes back unchanged instead of
    # raising from an impossible random draw.
    while tracks and len(truncated_prompt.split(" ")) >= expected_length:
        # Remove by index so the surviving tracks keep their relative order.
        tracks.pop(random.randrange(len(tracks)))
        truncated_prompt = self.get_piece_from_track_list(tracks)
        print("Prompt too long - deleting one track")

    return truncated_prompt
|
| 382 |
+
|
| 383 |
+
def generate_one_more_bar(self, track_index):
    """Extend the track at ``track_index`` by one newly generated bar.

    The prompt comes from ``process_prompt_for_next_bar``; generation is
    delegated to ``generate_until_track_end`` with the temperature stored for
    this track and an expected length of one bar. The new bar is extracted
    with ``get_newly_generated_bar`` and stored through
    ``update_track_dict__add_bars``.
    """
    generation_kwargs = {
        "input_prompt": self.process_prompt_for_next_bar(track_index),
        "temperature": self.piece_by_track[track_index]["temperature"],
        "expected_length": 1,
        "verbose": False,
    }
    generated_text = self.generate_until_track_end(**generation_kwargs)
    new_bar = self.get_newly_generated_bar(generated_text)
    self.update_track_dict__add_bars(new_bar, track_index)
|
| 395 |
|
| 396 |
def get_newly_generated_bar(self, prompt_plus_bar):
|
| 397 |
return "BAR_START " + self.striping_track_ends(
|
|
|
|
| 414 |
self.check_the_piece_for_errors()
|
| 415 |
|
| 416 |
def check_the_piece_for_errors(self, piece: str = None):
|
|
|
|
| 417 |
if piece is None:
|
| 418 |
piece = self.get_whole_piece_from_bar_dict()
|
| 419 |
errors = []
|
generation_utils.py
CHANGED
|
@@ -2,6 +2,7 @@ import os
|
|
| 2 |
import numpy as np
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
import matplotlib
|
|
|
|
| 5 |
|
| 6 |
from constants import INSTRUMENT_CLASSES
|
| 7 |
from playback import get_music, show_piano_roll
|
|
@@ -14,11 +15,38 @@ matplotlib.rcParams["axes.facecolor"] = "none"
|
|
| 14 |
matplotlib.rcParams["axes.edgecolor"] = "grey"
|
| 15 |
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
def bar_count_check(sequence, n_bars):
|
|
@@ -64,7 +92,8 @@ def check_if_prompt_density_in_tokenizer_vocab(tokenizer, density_prompt_list):
|
|
| 64 |
|
| 65 |
def forcing_bar_count(input_prompt, generated, bar_count, expected_length):
|
| 66 |
"""Forcing the generated sequence to have the expected length
|
| 67 |
-
expected_length and bar_count refers to the length of newly_generated_only (without input prompt)
|
|
|
|
| 68 |
|
| 69 |
if bar_count - expected_length > 0: # Cut the sequence if too long
|
| 70 |
full_piece = ""
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
import matplotlib
|
| 5 |
+
from utils import writeToFile, get_datetime
|
| 6 |
|
| 7 |
from constants import INSTRUMENT_CLASSES
|
| 8 |
from playback import get_music, show_piano_roll
|
|
|
|
| 15 |
matplotlib.rcParams["axes.edgecolor"] = "grey"
|
| 16 |
|
| 17 |
|
| 18 |
+
class WriteTextMidiToFile:
    """Utility saving the text-MIDI held by a GenerateMidiText object to a file.

    The output bundles the generated piece with the per-track dictionary
    (``piece_by_track``) of the generator. The file is named after the value
    returned by ``get_datetime`` with a ".json" suffix; the actual writing is
    delegated to ``writeToFile`` from ``utils``.
    """

    def __init__(self, generate_midi, output_path):
        """Keep the generated piece, the per-track data and the output folder.

        Args:
            generate_midi: object exposing ``generated_piece`` and
                ``piece_by_track`` attributes
            output_path: folder path used as prefix of the output file name
        """
        self.generated_midi = generate_midi.generated_piece
        self.output_path = output_path
        self.hyperparameter_and_bars = generate_midi.piece_by_track

    def hashing_seq(self):
        """Set ``current_time`` and derive ``output_path_filename`` from it."""
        self.current_time = get_datetime()
        self.output_path_filename = f"{self.output_path}/{self.current_time}.json"

    def wrapping_seq_hyperparameters_in_dict(self):
        """Return the dictionary that is handed to ``writeToFile``."""
        return {
            "generated_midi": self.generated_midi,
            "hyperparameters_and_bars": self.hyperparameter_and_bars,
        }

    def text_midi_to_file(self):
        """Write the piece and its per-track data to a new file.

        Returns:
            the path of the file that was written
        """
        self.hashing_seq()
        output_dict = self.wrapping_seq_hyperparameters_in_dict()
        writeToFile(self.output_path_filename, output_dict)
        # Report success only after the write call has returned, so a failing
        # write is never announced as done.
        print(f"Token generate_midi written: {self.output_path_filename}")
        return self.output_path_filename
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def define_generation_dir(generation_dir):
    """Make sure the generation directory exists and return its path.

    Args:
        generation_dir (str): path of the directory to create if missing

    Returns:
        str: ``generation_dir``, unchanged

    Raises:
        FileExistsError: if the path exists but is not a directory
    """
    # exist_ok avoids the check-then-create race of testing os.path.exists
    # first: a directory created in between no longer makes makedirs fail.
    os.makedirs(generation_dir, exist_ok=True)
    return generation_dir
|
| 50 |
|
| 51 |
|
| 52 |
def bar_count_check(sequence, n_bars):
|
|
|
|
| 92 |
|
| 93 |
def forcing_bar_count(input_prompt, generated, bar_count, expected_length):
|
| 94 |
"""Forcing the generated sequence to have the expected length
|
| 95 |
+
expected_length and bar_count refers to the length of newly_generated_only (without input prompt)
|
| 96 |
+
"""
|
| 97 |
|
| 98 |
if bar_count - expected_length > 0: # Cut the sequence if too long
|
| 99 |
full_piece = ""
|