from typing import List, Tuple

import gradio as gr
import note_seq
import torch
from matplotlib.figure import Figure
from numpy import ndarray

from constants import GM_INSTRUMENTS, SAMPLE_RATE
from model import get_model_and_tokenizer
from string_to_notes import token_sequence_to_note_sequence

# Select the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and the model.
model, tokenizer = get_model_and_tokenizer()


def create_seed_string(genre: str = "OTHER") -> str:
    """Creates the seed string that conditions generation on a genre."""
    if genre == "RANDOM":
        seed_string = "PIECE_START"
    else:
        seed_string = f"PIECE_START GENRE={genre} TRACK_START"
    return seed_string


def get_instruments(text_sequence: str) -> List[str]:
    """
    Extracts the list of instruments from a text sequence.

    Args:
        text_sequence (str): The text sequence.

    Returns:
        List[str]: The list of instruments.
    """
    instruments = []
    parts = text_sequence.split()
    for part in parts:
        if part.startswith("INST="):
            if part[5:] == "DRUMS":
                instruments.append("Drums")
            else:
                index = int(part[5:])
                instruments.append(GM_INSTRUMENTS[index])
    return instruments
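

# For example (instrument names here assume GM_INSTRUMENTS follows the
# General MIDI program order, where program 0 is "Acoustic Grand Piano"):
#
#     get_instruments("PIECE_START TRACK_START INST=0 ... TRACK_END "
#                     "TRACK_START INST=DRUMS ... TRACK_END")
#     # -> ["Acoustic Grand Piano", "Drums"]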


def generate_new_instrument(seed: str, temp: float = 0.75) -> str:
    """
    Generates a new instrument track conditioned on the seed string.

    Resamples until the part generated beyond the seed contains at least
    one "NOTE_ON" token, so an empty track is never returned.

    Args:
        seed (str): The conditioning token sequence.
        temp (float): The sampling temperature.

    Returns:
        str: The full token sequence, seed included.
    """
    seed_length = len(tokenizer.encode(seed))

    while True:
        # Encode the conditioning tokens.
        input_ids = tokenizer.encode(seed, return_tensors="pt")

        # Move the input_ids tensor to the same device as the model.
        input_ids = input_ids.to(model.device)

        # Generate more tokens, stopping at the end of the track.
        eos_token_id = tokenizer.encode("TRACK_END")[0]
        generated_ids = model.generate(
            input_ids,
            max_new_tokens=2048,
            do_sample=True,
            temperature=temp,
            eos_token_id=eos_token_id,
        )
        generated_sequence = tokenizer.decode(generated_ids[0])

        # Accept the sample only if the part beyond the seed contains notes.
        new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
        if "NOTE_ON" in new_generated_sequence:
            return generated_sequence
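

# For example, a call such as
#
#     generate_new_instrument("PIECE_START GENRE=OTHER TRACK_START")
#
# returns the seed followed by a newly sampled track that ends at
# "TRACK_END" and contains at least one "NOTE_ON" token. (The vocabulary
# beyond the tokens referenced above depends on the tokenizer.)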


def get_outputs_from_string(
    generated_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str]:
    """
    Renders a generated token sequence into audio, MIDI, and a note plot.

    Args:
        generated_sequence (str): The generated token sequence.
        qpm (int): The tempo in quarter notes per minute.

    Returns:
        Tuple[ndarray, str, Figure, str, str]: The waveform video, the MIDI
            file path, the note plot, the instrument list as a string, and
            the number of tokens as a string.
    """
    instruments = get_instruments(generated_sequence)
    instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
    note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)

    # Synthesize the note sequence to audio samples with FluidSynth.
    synth = note_seq.fluidsynth
    array_of_floats = synth(note_sequence, sample_rate=SAMPLE_RATE)
    int16_data = note_seq.audio_io.float_samples_to_int16(array_of_floats)

    fig = note_seq.plot_sequence(note_sequence, show_figure=False)
    num_tokens = str(len(generated_sequence.split()))

    # Render a waveform video from the synthesized audio.
    audio = gr.make_waveform((SAMPLE_RATE, int16_data))
    note_seq.note_sequence_to_midi_file(note_sequence, "midi_output.mid")
    return audio, "midi_output.mid", fig, instruments_str, num_tokens


def remove_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Removes the last instrument track from the song and re-renders it.

    Args:
        text_sequence (str): The current token sequence of the song.
        qpm (int): The tempo in quarter notes per minute.
    """
    # Split the song into tracks by splitting on 'TRACK_START'.
    tracks = text_sequence.split("TRACK_START")
    # Keep all tracks except the last one.
    modified_tracks = tracks[:-1]
    # Join the tracks back together, restoring the 'TRACK_START' tokens
    # that were consumed by split().
    new_song = "TRACK_START".join(modified_tracks)

    if len(tracks) == 2:
        # Only one instrument was present; regenerate a track from the
        # remaining song header.
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence=new_song, qpm=qpm
        )
    elif len(tracks) == 1:
        # No instrument at all, so generate from an empty sequence.
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence="", qpm=qpm
        )
    else:
        audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
            new_song, qpm
        )
    return audio, midi_file, fig, instruments_str, new_song, num_tokens


def regenerate_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Regenerates the last instrument track of the song.

    Args:
        text_sequence (str): The current token sequence of the song.
        qpm (int): The tempo in quarter notes per minute.
    """
    last_inst_index = text_sequence.rfind("INST=")
    if last_inst_index == -1:
        # No instrument yet, so generate from an empty sequence.
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence="", qpm=qpm
        )
    else:
        # Keep everything up to and including the last "INST=..." token,
        # then continue generation from that seed.
        next_space_index = text_sequence.find(" ", last_inst_index)
        new_seed = text_sequence[:next_space_index]
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence=new_seed, qpm=qpm
        )
    return audio, midi_file, fig, instruments_str, new_song, num_tokens


def change_tempo(
    text_sequence: str, qpm: int
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """Re-renders the current song at a new tempo without regenerating notes."""
    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        text_sequence, qpm=qpm
    )
    return audio, midi_file, fig, instruments_str, text_sequence, num_tokens


def generate_song(
    genre: str = "OTHER",
    temp: float = 0.75,
    text_sequence: str = "",
    qpm: int = 120,
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Generates a song by adding one instrument track to the given sequence.

    Args:
        genre (str): The genre to condition on when starting a new song.
        temp (float): The sampling temperature.
        text_sequence (str): The current token sequence; if empty, a new
            song is started from a genre seed.
        qpm (int): The tempo in quarter notes per minute.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The waveform video, the
            MIDI file path, the note plot, the instrument list as a string,
            the full token sequence, and the token count.
    """
    if text_sequence == "":
        seed_string = create_seed_string(genre)
    else:
        seed_string = text_sequence

    generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)
    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        generated_sequence, qpm
    )
    return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens
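

# --- Illustrative wiring (not part of the original module) ---
# A minimal sketch of how these callbacks could be hooked up to a Gradio
# Blocks UI. The component choices and labels below are assumptions for
# demonstration; the actual Space may lay things out differently.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        genre = gr.Dropdown(choices=["OTHER", "RANDOM"], value="OTHER", label="Genre")
        temp = gr.Slider(minimum=0.1, maximum=1.5, value=0.75, label="Temperature")
        qpm = gr.Slider(minimum=60, maximum=200, value=120, step=1, label="Tempo (QPM)")
        generate_btn = gr.Button("Generate instrument")

        audio_out = gr.Video(label="Audio")  # gr.make_waveform returns a video path
        midi_out = gr.File(label="MIDI file")
        plot_out = gr.Plot(label="Notes")
        instruments_out = gr.Textbox(label="Instruments")
        tokens_out = gr.Textbox(label="Number of tokens")
        sequence_state = gr.State("")  # carries the token sequence between calls

        generate_btn.click(
            generate_song,
            inputs=[genre, temp, sequence_state, qpm],
            outputs=[
                audio_out,
                midi_out,
                plot_out,
                instruments_out,
                sequence_state,
                tokens_out,
            ],
        )
    demo.launch()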