m41w4r3.exe
commited on
Commit
·
6cc2135
1
Parent(s):
facf84e
fix genesis caching
Browse files- decoder.py +1 -1
- generate.py +4 -4
- generation_utils.py +29 -9
- playground.py +56 -38
decoder.py
CHANGED
|
@@ -178,7 +178,7 @@ class TextDecoder:
|
|
| 178 |
inst = 0
|
| 179 |
is_drum = 1
|
| 180 |
if self.familized:
|
| 181 |
-
inst = Familizer(arbitrary=True).get_program_number(int(inst))
|
| 182 |
instruments.append((int(inst), is_drum))
|
| 183 |
return tuple(instruments)
|
| 184 |
|
|
|
|
| 178 |
inst = 0
|
| 179 |
is_drum = 1
|
| 180 |
if self.familized:
|
| 181 |
+
inst = Familizer(arbitrary=True).get_program_number(int(inst))
|
| 182 |
instruments.append((int(inst), is_drum))
|
| 183 |
return tuple(instruments)
|
| 184 |
|
generate.py
CHANGED
|
@@ -21,12 +21,12 @@ class GenerateMidiText:
|
|
| 21 |
- self.process_prompt_for_next_bar()
|
| 22 |
- self.generate_until_track_end()"""
|
| 23 |
|
| 24 |
-
def __init__(self, model, tokenizer):
|
| 25 |
self.model = model
|
| 26 |
self.tokenizer = tokenizer
|
| 27 |
# default initialization
|
| 28 |
self.initialize_default_parameters()
|
| 29 |
-
self.initialize_dictionaries()
|
| 30 |
|
| 31 |
"""Setters"""
|
| 32 |
|
|
@@ -38,8 +38,8 @@ class GenerateMidiText:
|
|
| 38 |
self.set_nb_bars_generated()
|
| 39 |
self.set_improvisation_level(0)
|
| 40 |
|
| 41 |
-
def initialize_dictionaries(self):
|
| 42 |
-
self.piece_by_track =
|
| 43 |
|
| 44 |
def set_device(self, device="cpu"):
|
| 45 |
self.device = ("cpu",)
|
|
|
|
| 21 |
- self.process_prompt_for_next_bar()
|
| 22 |
- self.generate_until_track_end()"""
|
| 23 |
|
| 24 |
+
def __init__(self, model, tokenizer, piece_by_track=[]):
|
| 25 |
self.model = model
|
| 26 |
self.tokenizer = tokenizer
|
| 27 |
# default initialization
|
| 28 |
self.initialize_default_parameters()
|
| 29 |
+
self.initialize_dictionaries(piece_by_track)
|
| 30 |
|
| 31 |
"""Setters"""
|
| 32 |
|
|
|
|
| 38 |
self.set_nb_bars_generated()
|
| 39 |
self.set_improvisation_level(0)
|
| 40 |
|
| 41 |
+
def initialize_dictionaries(self, piece_by_track):
|
| 42 |
+
self.piece_by_track = piece_by_track
|
| 43 |
|
| 44 |
def set_device(self, device="cpu"):
|
| 45 |
self.device = ("cpu",)
|
generation_utils.py
CHANGED
|
@@ -2,14 +2,16 @@ import os
|
|
| 2 |
import numpy as np
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
import matplotlib
|
|
|
|
| 5 |
from constants import INSTRUMENT_CLASSES
|
|
|
|
| 6 |
|
| 7 |
# matplotlib settings
|
| 8 |
matplotlib.use("Agg") # for server
|
| 9 |
matplotlib.rcParams["xtick.major.size"] = 0
|
| 10 |
matplotlib.rcParams["ytick.major.size"] = 0
|
| 11 |
-
matplotlib.rcParams["axes.facecolor"] = "
|
| 12 |
-
matplotlib.rcParams["axes.edgecolor"] = "
|
| 13 |
|
| 14 |
|
| 15 |
def define_generation_dir(model_repo_path):
|
|
@@ -93,7 +95,7 @@ def get_max_time(inst_midi):
|
|
| 93 |
def plot_piano_roll(inst_midi):
|
| 94 |
piano_roll_fig = plt.figure(figsize=(25, 3 * len(inst_midi.instruments)))
|
| 95 |
piano_roll_fig.tight_layout()
|
| 96 |
-
piano_roll_fig.patch.set_alpha(0
|
| 97 |
inst_count = 0
|
| 98 |
beats_per_bar = 4
|
| 99 |
sec_per_beat = 0.5
|
|
@@ -102,6 +104,14 @@ def plot_piano_roll(inst_midi):
|
|
| 102 |
int
|
| 103 |
)
|
| 104 |
for inst in inst_midi.instruments:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
inst_count += 1
|
| 106 |
plt.subplot(len(inst_midi.instruments), 1, inst_count)
|
| 107 |
|
|
@@ -118,24 +128,34 @@ def plot_piano_roll(inst_midi):
|
|
| 118 |
for note in p_midi_note_list:
|
| 119 |
note_time.append([note.start, note.end])
|
| 120 |
note_pitch.append([note.pitch, note.pitch])
|
|
|
|
|
|
|
| 121 |
|
| 122 |
plt.plot(
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
color=
|
| 126 |
-
linewidth=
|
| 127 |
solid_capstyle="butt",
|
| 128 |
)
|
| 129 |
plt.ylim(0, 128)
|
| 130 |
xticks = np.array(bars_time)[:-1]
|
| 131 |
plt.tight_layout()
|
| 132 |
plt.xlim(min(bars_time), max(bars_time))
|
| 133 |
-
|
| 134 |
plt.xticks(
|
| 135 |
xticks + 0.5 * beats_per_bar * sec_per_beat,
|
| 136 |
labels=xticks.argsort() + 1,
|
| 137 |
visible=False,
|
| 138 |
)
|
| 139 |
-
plt.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
return piano_roll_fig
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
import matplotlib
|
| 5 |
+
|
| 6 |
from constants import INSTRUMENT_CLASSES
|
| 7 |
+
from playback import get_music, show_piano_roll
|
| 8 |
|
| 9 |
# matplotlib settings
|
| 10 |
matplotlib.use("Agg") # for server
|
| 11 |
matplotlib.rcParams["xtick.major.size"] = 0
|
| 12 |
matplotlib.rcParams["ytick.major.size"] = 0
|
| 13 |
+
matplotlib.rcParams["axes.facecolor"] = "none"
|
| 14 |
+
matplotlib.rcParams["axes.edgecolor"] = "grey"
|
| 15 |
|
| 16 |
|
| 17 |
def define_generation_dir(model_repo_path):
|
|
|
|
| 95 |
def plot_piano_roll(inst_midi):
|
| 96 |
piano_roll_fig = plt.figure(figsize=(25, 3 * len(inst_midi.instruments)))
|
| 97 |
piano_roll_fig.tight_layout()
|
| 98 |
+
piano_roll_fig.patch.set_alpha(0)
|
| 99 |
inst_count = 0
|
| 100 |
beats_per_bar = 4
|
| 101 |
sec_per_beat = 0.5
|
|
|
|
| 104 |
int
|
| 105 |
)
|
| 106 |
for inst in inst_midi.instruments:
|
| 107 |
+
# hardcoded for now
|
| 108 |
+
if inst.name == "Drums":
|
| 109 |
+
color = "purple"
|
| 110 |
+
elif inst.name == "Synth Bass 1":
|
| 111 |
+
color = "orange"
|
| 112 |
+
else:
|
| 113 |
+
color = "green"
|
| 114 |
+
|
| 115 |
inst_count += 1
|
| 116 |
plt.subplot(len(inst_midi.instruments), 1, inst_count)
|
| 117 |
|
|
|
|
| 128 |
for note in p_midi_note_list:
|
| 129 |
note_time.append([note.start, note.end])
|
| 130 |
note_pitch.append([note.pitch, note.pitch])
|
| 131 |
+
note_pitch = np.array(note_pitch)
|
| 132 |
+
note_time = np.array(note_time)
|
| 133 |
|
| 134 |
plt.plot(
|
| 135 |
+
note_time.T,
|
| 136 |
+
note_pitch.T,
|
| 137 |
+
color=color,
|
| 138 |
+
linewidth=4,
|
| 139 |
solid_capstyle="butt",
|
| 140 |
)
|
| 141 |
plt.ylim(0, 128)
|
| 142 |
xticks = np.array(bars_time)[:-1]
|
| 143 |
plt.tight_layout()
|
| 144 |
plt.xlim(min(bars_time), max(bars_time))
|
| 145 |
+
plt.ylim(max([note_pitch.min() - 5, 0]), note_pitch.max() + 5)
|
| 146 |
plt.xticks(
|
| 147 |
xticks + 0.5 * beats_per_bar * sec_per_beat,
|
| 148 |
labels=xticks.argsort() + 1,
|
| 149 |
visible=False,
|
| 150 |
)
|
| 151 |
+
plt.text(
|
| 152 |
+
0.2,
|
| 153 |
+
note_pitch.max() + 4,
|
| 154 |
+
inst.name,
|
| 155 |
+
fontsize=20,
|
| 156 |
+
color=color,
|
| 157 |
+
horizontalalignment="left",
|
| 158 |
+
verticalalignment="top",
|
| 159 |
+
)
|
| 160 |
|
| 161 |
return piano_roll_fig
|
playground.py
CHANGED
|
@@ -26,7 +26,6 @@ model, tokenizer = LoadModel(
|
|
| 26 |
model_repo, from_huggingface=True, revision=revision
|
| 27 |
).load_model_and_tokenizer()
|
| 28 |
|
| 29 |
-
|
| 30 |
miditok = get_miditok()
|
| 31 |
decoder = TextDecoder(miditok)
|
| 32 |
|
|
@@ -40,32 +39,49 @@ def define_prompt(state, genesis):
|
|
| 40 |
|
| 41 |
|
| 42 |
def generator(
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
):
|
| 45 |
|
|
|
|
|
|
|
| 46 |
inst = next(
|
| 47 |
(inst for inst in INSTRUMENT_CLASSES if inst["name"] == instrument),
|
| 48 |
{"family_number": "DRUMS"},
|
| 49 |
)["family_number"]
|
| 50 |
|
| 51 |
-
inst_index =
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
genesis.delete_one_track(inst_index)
|
| 57 |
-
generated_text = (
|
| 58 |
-
genesis.get_whole_piece_from_bar_dict()
|
| 59 |
-
) # maybe not useful here
|
| 60 |
-
inst_index = -1 # reset to last generated
|
| 61 |
|
| 62 |
# Generate
|
| 63 |
if not add_bars:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
# NEW TRACK
|
| 65 |
input_prompt = define_prompt(state, genesis)
|
| 66 |
generated_text = genesis.generate_one_new_track(
|
| 67 |
inst, density, temp, input_prompt=input_prompt
|
| 68 |
)
|
|
|
|
|
|
|
| 69 |
else:
|
| 70 |
# NEW BARS
|
| 71 |
genesis.generate_n_more_bars(add_bar_count) # for all instruments
|
|
@@ -79,14 +95,23 @@ def generator(
|
|
| 79 |
decoder.get_midi(inst_text, inst_midi_name)
|
| 80 |
_, inst_audio = get_music(inst_midi_name)
|
| 81 |
piano_roll = plot_piano_roll(mixed_inst_midi)
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
-
def instrument_row(default_inst):
|
| 88 |
|
|
|
|
| 89 |
with gr.Row():
|
|
|
|
| 90 |
with gr.Column(scale=1, min_width=50):
|
| 91 |
inst = gr.Dropdown(
|
| 92 |
[inst["name"] for inst in INSTRUMENT_CLASSES] + ["Drums"],
|
|
@@ -100,35 +125,33 @@ def instrument_row(default_inst):
|
|
| 100 |
output_txt = gr.Textbox(label="output", lines=10, max_lines=10)
|
| 101 |
with gr.Column(scale=1, min_width=100):
|
| 102 |
inst_audio = gr.Audio(label="Audio")
|
| 103 |
-
regenerate = gr.Checkbox(value=False, label="Regenerate")
|
| 104 |
# add_bars = gr.Checkbox(value=False, label="Add Bars")
|
| 105 |
# add_bar_count = gr.Dropdown([1, 2, 4, 8], value=1, label="Add Bars")
|
| 106 |
gen_btn = gr.Button("Generate")
|
| 107 |
gen_btn.click(
|
| 108 |
fn=generator,
|
| 109 |
-
inputs=[
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
state,
|
|
|
|
|
|
|
|
|
|
| 115 |
],
|
| 116 |
-
outputs=[output_txt, inst_audio, piano_roll, state, mixed_audio],
|
| 117 |
)
|
| 118 |
|
| 119 |
|
| 120 |
-
with gr.Blocks(
|
| 121 |
-
|
| 122 |
-
model,
|
| 123 |
-
tokenizer,
|
| 124 |
-
)
|
| 125 |
-
genesis.set_nb_bars_generated(n_bars=n_bar_generated)
|
| 126 |
state = gr.State([])
|
| 127 |
mixed_audio = gr.Audio(label="Mixed Audio")
|
| 128 |
piano_roll = gr.Plot(label="Piano Roll")
|
| 129 |
-
instrument_row("Drums")
|
| 130 |
-
instrument_row("Bass")
|
| 131 |
-
instrument_row("Synth Lead")
|
| 132 |
# instrument_row("Piano")
|
| 133 |
|
| 134 |
demo.launch(debug=True)
|
|
@@ -138,14 +161,9 @@ TODO: DEPLOY
|
|
| 138 |
TODO: temp file situation
|
| 139 |
TODO: clear cache situation
|
| 140 |
TODO: reset button
|
| 141 |
-
TODO: instrument mapping business
|
| 142 |
-
TODO: Y lim axis of piano roll
|
| 143 |
TODO: add a button to save the generated midi
|
| 144 |
TODO: add improvise button
|
| 145 |
-
TODO: making the piano roll fit on the horizontal scale
|
| 146 |
TODO: set values for temperature as it is done for density
|
| 147 |
-
TODO: set the color situation to be dark background
|
| 148 |
-
TODO: make regeration default when an intrument has already been track has already been generated
|
| 149 |
TODO: Add bar should be now set for the whole piece - regenerrate should regenerate the added bars only on all instruments
|
| 150 |
TODO: row height to fix
|
| 151 |
|
|
|
|
| 26 |
model_repo, from_huggingface=True, revision=revision
|
| 27 |
).load_model_and_tokenizer()
|
| 28 |
|
|
|
|
| 29 |
miditok = get_miditok()
|
| 30 |
decoder = TextDecoder(miditok)
|
| 31 |
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
def generator(
|
| 42 |
+
label,
|
| 43 |
+
regenerate,
|
| 44 |
+
temp,
|
| 45 |
+
density,
|
| 46 |
+
instrument,
|
| 47 |
+
state,
|
| 48 |
+
piece_by_track,
|
| 49 |
+
add_bars=False,
|
| 50 |
+
add_bar_count=1,
|
| 51 |
):
|
| 52 |
|
| 53 |
+
genesis = GenerateMidiText(model, tokenizer, piece_by_track)
|
| 54 |
+
track = {"label": label}
|
| 55 |
inst = next(
|
| 56 |
(inst for inst in INSTRUMENT_CLASSES if inst["name"] == instrument),
|
| 57 |
{"family_number": "DRUMS"},
|
| 58 |
)["family_number"]
|
| 59 |
|
| 60 |
+
inst_index = -1 # default to last generated
|
| 61 |
+
if state != []:
|
| 62 |
+
for index, instrum in enumerate(state):
|
| 63 |
+
if instrum["label"] == track["label"]:
|
| 64 |
+
inst_index = index # changing if exists
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
# Generate
|
| 67 |
if not add_bars:
|
| 68 |
+
# Regenerate
|
| 69 |
+
if regenerate:
|
| 70 |
+
state.pop(inst_index)
|
| 71 |
+
genesis.delete_one_track(inst_index)
|
| 72 |
+
|
| 73 |
+
generated_text = (
|
| 74 |
+
genesis.get_whole_piece_from_bar_dict()
|
| 75 |
+
) # maybe not useful here
|
| 76 |
+
inst_index = -1 # reset to last generated
|
| 77 |
+
|
| 78 |
# NEW TRACK
|
| 79 |
input_prompt = define_prompt(state, genesis)
|
| 80 |
generated_text = genesis.generate_one_new_track(
|
| 81 |
inst, density, temp, input_prompt=input_prompt
|
| 82 |
)
|
| 83 |
+
|
| 84 |
+
regenerate = True # set generate to true
|
| 85 |
else:
|
| 86 |
# NEW BARS
|
| 87 |
genesis.generate_n_more_bars(add_bar_count) # for all instruments
|
|
|
|
| 95 |
decoder.get_midi(inst_text, inst_midi_name)
|
| 96 |
_, inst_audio = get_music(inst_midi_name)
|
| 97 |
piano_roll = plot_piano_roll(mixed_inst_midi)
|
| 98 |
+
track["text"] = inst_text
|
| 99 |
+
state.append(track)
|
| 100 |
+
|
| 101 |
+
return (
|
| 102 |
+
inst_text,
|
| 103 |
+
(44100, inst_audio),
|
| 104 |
+
piano_roll,
|
| 105 |
+
state,
|
| 106 |
+
(44100, mixed_audio),
|
| 107 |
+
regenerate,
|
| 108 |
+
genesis.piece_by_track,
|
| 109 |
+
)
|
| 110 |
|
|
|
|
| 111 |
|
| 112 |
+
def instrument_row(default_inst, row_id):
|
| 113 |
with gr.Row():
|
| 114 |
+
row = gr.Variable(row_id)
|
| 115 |
with gr.Column(scale=1, min_width=50):
|
| 116 |
inst = gr.Dropdown(
|
| 117 |
[inst["name"] for inst in INSTRUMENT_CLASSES] + ["Drums"],
|
|
|
|
| 125 |
output_txt = gr.Textbox(label="output", lines=10, max_lines=10)
|
| 126 |
with gr.Column(scale=1, min_width=100):
|
| 127 |
inst_audio = gr.Audio(label="Audio")
|
| 128 |
+
regenerate = gr.Checkbox(value=False, label="Regenerate", visible=False)
|
| 129 |
# add_bars = gr.Checkbox(value=False, label="Add Bars")
|
| 130 |
# add_bar_count = gr.Dropdown([1, 2, 4, 8], value=1, label="Add Bars")
|
| 131 |
gen_btn = gr.Button("Generate")
|
| 132 |
gen_btn.click(
|
| 133 |
fn=generator,
|
| 134 |
+
inputs=[row, regenerate, temp, density, inst, state, piece_by_track],
|
| 135 |
+
outputs=[
|
| 136 |
+
output_txt,
|
| 137 |
+
inst_audio,
|
| 138 |
+
piano_roll,
|
| 139 |
state,
|
| 140 |
+
mixed_audio,
|
| 141 |
+
regenerate,
|
| 142 |
+
piece_by_track,
|
| 143 |
],
|
|
|
|
| 144 |
)
|
| 145 |
|
| 146 |
|
| 147 |
+
with gr.Blocks() as demo:
|
| 148 |
+
piece_by_track = gr.State([])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
state = gr.State([])
|
| 150 |
mixed_audio = gr.Audio(label="Mixed Audio")
|
| 151 |
piano_roll = gr.Plot(label="Piano Roll")
|
| 152 |
+
instrument_row("Drums", 0)
|
| 153 |
+
instrument_row("Bass", 1)
|
| 154 |
+
instrument_row("Synth Lead", 2)
|
| 155 |
# instrument_row("Piano")
|
| 156 |
|
| 157 |
demo.launch(debug=True)
|
|
|
|
| 161 |
TODO: temp file situation
|
| 162 |
TODO: clear cache situation
|
| 163 |
TODO: reset button
|
|
|
|
|
|
|
| 164 |
TODO: add a button to save the generated midi
|
| 165 |
TODO: add improvise button
|
|
|
|
| 166 |
TODO: set values for temperature as it is done for density
|
|
|
|
|
|
|
| 167 |
TODO: Add bar should be now set for the whole piece - regenerrate should regenerate the added bars only on all instruments
|
| 168 |
TODO: row height to fix
|
| 169 |
|