Spaces:
Runtime error
Runtime error
| from generation_utils import * | |
| from utils import WriteTextMidiToFile, get_miditok | |
| from load import LoadModel | |
| from decoder import TextDecoder | |
| from playback import get_music | |
class GenerateMidiText:
    """Generate music as MidiText, one track (or one extra bar) at a time.

    LOGIC:
    FOR GENERATING FROM SCRATCH:
    - self.generate_one_new_track()
      it calls
      - self.generate_until_track_end()
    FOR GENERATING NEW BARS:
    - self.generate_one_more_bar()
      it calls
      - self.process_prompt_for_next_bar()
      - self.generate_until_track_end()

    State lives in ``self.piece_by_track``: a list with one dict per track
    holding ``label``, ``instrument``, ``density``, ``temperature`` and
    ``bars``. ``bars[0]`` is the track header (the text before the first
    ``BAR_START``, e.g. ``"TRACK_START INST=DRUMS DENSITY=2 "``); every later
    entry is one ``"BAR_START ... "`` bar.
    """

    def __init__(self, model, tokenizer, piece_by_track=None):
        """Create a generator.

        Args:
            model: object exposing ``generate(...)`` and ``config.n_positions``
                (presumably a Hugging Face causal LM -- TODO confirm).
            tokenizer: object exposing ``encode``, ``decode`` and ``vocab``.
            piece_by_track (list | None): tracks to start from. When omitted,
                a fresh empty list is created for this instance.
        """
        self.model = model
        self.tokenizer = tokenizer
        # default initialization
        self.initialize_default_parameters()
        # A literal ``[]`` default would be a single list shared by every
        # instance, so tracks generated for one piece would leak into the next.
        if piece_by_track is None:
            piece_by_track = []
        self.initialize_dictionaries(piece_by_track)

    # ----------------------------------------------------------------- Setters

    def initialize_default_parameters(self):
        """Apply the default generation settings."""
        self.set_device()
        self.set_attention_length()
        # Token used as ``eos_token_id`` in generate_sequence_of_token_ids.
        self.generate_until = "TRACK_END"
        self.set_force_sequence_lenth()
        self.set_nb_bars_generated()
        self.set_improvisation_level(0)

    def initialize_dictionaries(self, piece_by_track):
        """Set the per-track storage list."""
        self.piece_by_track = piece_by_track

    def set_device(self, device="cpu"):
        """Store the requested device name."""
        self.device = device

    def set_attention_length(self):
        """Cap generation length at the model's ``config.n_positions``."""
        self.max_length = self.model.config.n_positions
        print(
            f"Attention length set to {self.max_length} -> 'model.config.n_positions'"
        )

    def set_force_sequence_lenth(self, force_sequence_length=True):
        """Enable/disable forcing the generated bar count (name kept for callers)."""
        self.force_sequence_length = force_sequence_length

    def set_improvisation_level(self, improvisation_value):
        """Set ``no_repeat_ngram_size`` passed to ``model.generate``."""
        self.no_repeat_ngram_size = improvisation_value
        print("--------------------")
        print(f"no_repeat_ngram_size set to {improvisation_value}")
        print("--------------------")

    def reset_temperatures(self, track_id, temperature):
        """Change the sampling temperature stored for one track."""
        self.piece_by_track[track_id]["temperature"] = temperature

    def set_nb_bars_generated(self, n_bars=8):  # default is a 8 bar model
        """Set how many bars the model generates per track."""
        self.model_n_bar = n_bars

    # ------------------------------------------- Generation tools - dictionaries

    def initiate_track_dict(self, instr, density, temperature):
        """Append an empty track entry labelled ``track_<index>``."""
        label = len(self.piece_by_track)
        self.piece_by_track.append(
            {
                "label": f"track_{label}",
                "instrument": instr,
                "density": density,
                "temperature": temperature,
                "bars": [],
            }
        )

    def update_track_dict__add_bars(self, bars, track_id):
        """Split ``bars`` on ``BAR_START`` and append each bar to a track.

        The chunk containing ``TRACK_START`` (the header) is stored as-is;
        every other chunk gets its ``"BAR_START "`` prefix back.
        """
        for bar in self.striping_track_ends(bars).split("BAR_START "):
            if bar == "":  # happens if there is one bar only
                continue
            if "TRACK_START" in bar:
                self.piece_by_track[track_id]["bars"].append(bar)
            else:
                self.piece_by_track[track_id]["bars"].append("BAR_START " + bar)

    def get_all_instr_bars(self, track_id):
        """Return the stored bar list (header first) of one track."""
        return self.piece_by_track[track_id]["bars"]

    def striping_track_ends(self, text):
        """Drop a trailing ``TRACK_END`` token from ``text``.

        Only when ``text`` contains ``TRACK_END``: trailing spaces are removed
        and then a single ``TRACK_END`` suffix, if present. ``str.rstrip`` is
        not used for the token because it strips a *set of characters*, which
        would also eat e.g. a final ``BAR_END`` down to ``B``.
        """
        if "TRACK_END" in text:
            text = text.rstrip(" ")
            if text.endswith("TRACK_END"):
                text = text[: -len("TRACK_END")]
        return text

    def get_last_generated_track(self, full_piece):
        """Return the last track of ``full_piece`` as ``"TRACK_START ... TRACK_END "``."""
        last_track = full_piece.split("TRACK_START ")[-1]
        # forcing the space after TRACK_END
        return "TRACK_START " + self.striping_track_ends(last_track) + "TRACK_END "

    def get_selected_track_as_text(self, track_id):
        """Concatenate a track's header and bars and close it with ``TRACK_END``."""
        return "".join(self.piece_by_track[track_id]["bars"]) + "TRACK_END "

    def get_newly_generated_text(self, input_prompt, full_piece):
        """Return the part of ``full_piece`` after the first ``len(input_prompt)`` chars.

        Assumes ``full_piece`` starts with ``input_prompt``.
        """
        return full_piece[len(input_prompt) :]

    def get_whole_piece_from_bar_dict(self):
        """Rebuild the full MidiText piece from ``self.piece_by_track``."""
        return "PIECE_START " + "".join(
            self.get_selected_track_as_text(track_id)
            for track_id in range(len(self.piece_by_track))
        )

    def delete_one_track(self, track):  # TO BE TESTED
        """Remove the track at index ``track``."""
        self.piece_by_track.pop(track)

    # -------------------------------------------------- Basic generation tools

    def tokenize_input_prompt(self, input_prompt, verbose=True):
        """Tokenize a prompt.

        Args:
            input_prompt (str): prompt to tokenize.
            verbose (bool): print progress.

        Returns:
            input_prompt_ids (torch.tensor): tokenized prompt.
        """
        if verbose:
            print("Tokenizing input_prompt...")
        return self.tokenizer.encode(input_prompt, return_tensors="pt")

    def generate_sequence_of_token_ids(
        self,
        input_prompt_ids,
        temperature,
        verbose=True,
    ):
        """Sample a token-id sequence continuing ``input_prompt_ids``.

        Generation stops at the ``self.generate_until`` token or at
        ``self.max_length`` tokens (``model.config.n_positions``).
        """
        if verbose:
            print("Generating a token_id sequence...")
        return self.model.generate(
            input_prompt_ids,
            max_length=self.max_length,
            do_sample=True,
            temperature=temperature,
            no_repeat_ngram_size=self.no_repeat_ngram_size,  # default = 0
            eos_token_id=self.tokenizer.encode(self.generate_until)[0],
        )

    def convert_ids_to_text(self, generated_ids, verbose=True):
        """Decode the first sequence of ``generated_ids`` to MidiText."""
        generated_text = self.tokenizer.decode(generated_ids[0])
        if verbose:
            print("Converting token sequence to MidiText...")
        return generated_text

    def generate_until_track_end(
        self,
        input_prompt="PIECE_START ",
        instrument=None,
        density=None,
        temperature=None,
        verbose=True,
        expected_length=None,
    ):
        """Generate until the TRACK_END token is reached.

        Args:
            input_prompt (str): text to continue.
            instrument: when given, ``TRACK_START INST=<instrument> `` is
                appended to the prompt (and ``DENSITY=<density> `` if given).
            density: track density; ignored (with a warning) without instrument.
            temperature (float): sampling temperature; required.
            verbose (bool): print progress.
            expected_length (int | None): expected bar count of the generated
                part; defaults to ``self.model_n_bar``.

        Returns:
            str: full_piece = input_prompt + generated text.

        Raises:
            ValueError: if ``temperature`` is None.
        """
        if temperature is None:
            raise ValueError("Temperature must be defined")
        if expected_length is None:
            expected_length = self.model_n_bar
        if instrument is not None:
            input_prompt = f"{input_prompt}TRACK_START INST={instrument} "
            if density is not None:
                input_prompt = f"{input_prompt}DENSITY={density} "
        if instrument is None and density is not None:
            print("Density cannot be defined without an input_prompt instrument #TOFIX")
        if verbose:
            print("--------------------")
            print(
                f"Generating {instrument} - Density {density} - temperature {temperature}"
            )
        # The prompt does not change between attempts: tokenize it once.
        input_prompt_ids = self.tokenize_input_prompt(input_prompt, verbose=verbose)
        bar_count_checks = False
        failed = 0
        while not bar_count_checks:  # regenerate until right length
            generated_tokens = self.generate_sequence_of_token_ids(
                input_prompt_ids, temperature, verbose=verbose
            )
            full_piece = self.convert_ids_to_text(generated_tokens, verbose=verbose)
            generated = self.get_newly_generated_text(input_prompt, full_piece)
            bar_count_checks, bar_count = bar_count_check(generated, expected_length)
            if not self.force_sequence_length:
                # accept whatever length came out and exit the loop
                bar_count_checks = True
            if not bar_count_checks and self.force_sequence_length:
                # Wrong length. "Regenerate" is deactivated for speed: the
                # condition below is always true, so the output is repaired
                # with forcing_bar_count instead.
                if failed > -1:
                    full_piece, bar_count_checks = forcing_bar_count(
                        input_prompt,
                        generated,
                        bar_count,
                        expected_length,
                    )
                else:
                    print("--- Wrong length - Regenerating ---")
            if not bar_count_checks:
                failed += 1
                if failed > 2:
                    bar_count_checks = True  # TOFIX exit the while loop
        return full_piece

    def generate_one_new_track(
        self,
        instrument,
        density,
        temperature,
        input_prompt="PIECE_START ",
    ):
        """Generate a new track after ``input_prompt`` and store its bars.

        Returns:
            str: the whole piece rebuilt from ``self.piece_by_track``.
        """
        self.initiate_track_dict(instrument, density, temperature)
        full_piece = self.generate_until_track_end(
            input_prompt=input_prompt,
            instrument=instrument,
            density=density,
            temperature=temperature,
        )
        track = self.get_last_generated_track(full_piece)
        self.update_track_dict__add_bars(track, -1)
        return self.get_whole_piece_from_bar_dict()

    # -------------------------------------------------- Piece generation - basics

    def generate_piece(self, instrument_list, density_list, temperature_list):
        """Generate a multi-track piece, one track per instrument.

        Instruments, densities and temperatures are paired by position and
        generated in list order. Each track is prompted with all previously
        generated tracks, so the first instrument is generated with less bias
        than the following ones.

        Returns:
            str: the generated piece (as rebuilt by generate_one_new_track).
        """
        generated_piece = "PIECE_START "
        for instrument, density, temperature in zip(
            instrument_list, density_list, temperature_list
        ):
            generated_piece = self.generate_one_new_track(
                instrument,
                density,
                temperature,
                input_prompt=generated_piece,
            )
        self.check_the_piece_for_errors()
        return generated_piece

    # --------------------------------------------- Piece generation - extra bars

    @staticmethod
    def _last_bars(bars, n):
        """Return up to the last ``n`` bars of ``bars``, never the header ``bars[0]``."""
        body = bars[1:]
        # n <= 0 gives an empty list (a plain ``[-0:]`` would return everything).
        return body[max(len(body) - n, 0) :]

    def process_prompt_for_next_bar(self, track_idx):
        """Build the prompt that makes the model generate exactly one more bar.

        The prompt contains:
            - for every other track that is already one (or more) bar ahead:
              its header and its last ``self.model_n_bar`` bars, so the context
              covers the bar about to be generated;
            - for the track to prolong: its header, its last
              ``self.model_n_bar - 1`` bars and an open ``"BAR_START "``.

        Tracks that are not ahead are left out: padding them with empty bars
        made the model generate empty bars.

        Args:
            track_idx (int): index of the track to prolong.

        Returns:
            str: the processed prompt.
        """
        track = self.piece_by_track[track_idx]
        pre_prompt = "PIECE_START "
        for i, othertrack in enumerate(self.piece_by_track):
            if i == track_idx:
                continue
            if len(othertrack["bars"]) > len(track["bars"]):
                pre_prompt += othertrack["bars"][0]
                pre_prompt += "".join(
                    self._last_bars(othertrack["bars"], self.model_n_bar)
                )
                pre_prompt += "TRACK_END "
        # initialization e.g. TRACK_START INST=DRUMS DENSITY=2
        processed_prompt = (
            track["bars"][0]
            + "".join(self._last_bars(track["bars"], self.model_n_bar - 1))
            + "BAR_START "
        )
        prompt = pre_prompt + processed_prompt
        print(f"--- prompt length = {len(prompt.split(' '))} ---")
        return prompt

    def generate_one_more_bar(self, i):
        """Generate one more bar for track ``i`` and append it to that track."""
        processed_prompt = self.process_prompt_for_next_bar(i)
        prompt_plus_bar = self.generate_until_track_end(
            input_prompt=processed_prompt,
            temperature=self.piece_by_track[i]["temperature"],
            expected_length=1,
            verbose=False,
        )
        added_bar = self.get_newly_generated_bar(prompt_plus_bar)
        self.update_track_dict__add_bars(added_bar, i)

    def get_newly_generated_bar(self, prompt_plus_bar):
        """Return the last ``BAR_START`` chunk of the text, without TRACK_END."""
        return "BAR_START " + self.striping_track_ends(
            prompt_plus_bar.split("BAR_START ")[-1]
        )

    def generate_n_more_bars(self, n_bars, only_this_track=None, verbose=True):
        """Add ``n_bars`` bars to every track, or only to ``only_this_track``."""
        if verbose:
            print("================== ")
            print(f"Adding {n_bars} more bars to the piece ")
        for bar_id in range(n_bars):
            if verbose:
                print(f"----- added bar #{bar_id+1} --")
            for i, track in enumerate(self.piece_by_track):
                if only_this_track is None or i == only_this_track:
                    if verbose:
                        print(f"--------- {track['label']}")
                    self.generate_one_more_bar(i)
        self.check_the_piece_for_errors()

    def check_the_piece_for_errors(self, piece: str = None):
        """Report tokens of ``piece`` that are unknown to the tokenizer.

        Args:
            piece: MidiText to check; defaults to the piece rebuilt from
                ``self.piece_by_track``.

        Returns:
            list of (token, position) tuples for every token that is not in
            ``tokenizer.vocab`` or equals ``"UNK"``. Empty strings produced by
            the trailing space separator are ignored.
        """
        if piece is None:
            piece = self.get_whole_piece_from_bar_dict()
        tokens = piece.split(" ")
        # Read the vocabulary once; on some tokenizers ``vocab`` may be a
        # property that builds a new dict on every access.
        vocab = self.tokenizer.vocab
        errors = [
            (token, position)
            for position, token in enumerate(tokens)
            if token and (token not in vocab or token == "UNK")
        ]
        for token, position in errors:
            print(f"Token not found in the piece at {position}: {token}")
            # clamp: a negative start would wrap around to the end of the list
            print(tokens[max(position - 5, 0) : position + 5])
        return errors
if __name__ == "__main__":
    import os

    # worker
    DEVICE = "cpu"

    # define generation parameters
    N_FILES_TO_GENERATE = 2
    Temperatures_to_try = [0.7]
    USE_FAMILIZED_MODEL = True
    force_sequence_length = True

    if USE_FAMILIZED_MODEL:
        # model_repo = "misnaej/the-jam-machine-elec-famil"
        # model_repo = "misnaej/the-jam-machine-elec-famil-ft32"
        # model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
        # n_bar_generated = 8
        model_repo = "JammyMachina/improved_4bars-mdl"
        n_bar_generated = 4
        instrument_promt_list = ["4", "DRUMS", "3"]
        # DRUMS = drums, 0 = piano, 1 = chromatic percussion, 2 = organ, 3 = guitar, 4 = bass, 5 = strings, 6 = ensemble, 7 = brass, 8 = reed, 9 = pipe, 10 = synth lead, 11 = synth pad, 12 = synth effects, 13 = ethnic, 14 = percussive, 15 = sound effects
        density_list = [3, 2, 2]
    else:
        model_repo = "misnaej/the-jam-machine"
        # n_bar_generated was previously undefined on this branch, which made
        # the script crash with a NameError below. 8 matches the generator's
        # default ("default is a 8 bar model") -- TODO confirm for this repo.
        n_bar_generated = 8
        instrument_promt_list = ["30"]
        density_list = [3]

    # define generation directory
    generated_sequence_files_path = define_generation_dir(model_repo)

    # load model and tokenizer
    model, tokenizer = LoadModel(
        model_repo, from_huggingface=True
    ).load_model_and_tokenizer()

    # does the prompt make sense
    check_if_prompt_inst_in_tokenizer_vocab(tokenizer, instrument_promt_list)

    for temperature in Temperatures_to_try:
        print(f"================= TEMPERATURE {temperature} =======================")
        for _ in range(N_FILES_TO_GENERATE):
            print("========================================")
            # 1 - instantiate; an explicit fresh list keeps each file's tracks
            # separate from the previous file's
            generate_midi = GenerateMidiText(model, tokenizer, piece_by_track=[])
            generate_midi.set_device(DEVICE)
            # 0 - set the n_bar for this model
            generate_midi.set_nb_bars_generated(n_bars=n_bar_generated)
            generate_midi.set_force_sequence_lenth(force_sequence_length)
            # 2 - generate the first bars for each instrument
            generate_midi.set_improvisation_level(30)
            generate_midi.generate_piece(
                instrument_promt_list,
                density_list,
                [temperature for _ in density_list],
            )
            # 3 - (optional) force the model to improvise and add more bars:
            # generate_midi.set_improvisation_level(20)
            # generate_midi.generate_n_more_bars(n_bar_generated)
            generate_midi.generated_piece = (
                generate_midi.get_whole_piece_from_bar_dict()
            )

            # print the generated sequence in terminal
            print("=========================================")
            print(generate_midi.generated_piece)
            print("=========================================")

            # write to JSON file
            filename = WriteTextMidiToFile(
                generate_midi,
                generated_sequence_files_path,
            ).text_midi_to_file()

            # decode the sequence to MIDI; splitext only drops the extension,
            # unlike split(".") which breaks on paths such as "./out/x.json"
            midi_filename = os.path.splitext(filename)[0] + ".mid"
            decode_tokenizer = get_miditok()
            TextDecoder(decode_tokenizer, USE_FAMILIZED_MODEL).get_midi(
                generate_midi.generated_piece, filename=midi_filename
            )
            inst_midi, mixed_audio = get_music(midi_filename)
            max_time = get_max_time(inst_midi)
            plot_piano_roll(inst_midi)

            print("Et voilà! Your MIDI file is ready! GO JAM!")