# auto-beater / app.py
# Achillefs Sourlas
# Small change
# 46e7440
import gradio as gr
from beatgenerator import BeatGenerator
from datetime import datetime
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import json
from json.decoder import JSONDecodeError
# Sequencer grid dimensions. In the visible code these are only referenced by
# the commented-out checkbox UI below (STEP_COUNT checkboxes per row, one row
# per instrument); the live JSON endpoint does not use them.
STEP_COUNT = 32
INSTRUMENT_COUNT = 9
# Load the fine-tuned GPT-2 weights and tokenizer once, at import time, from
# directories given as paths relative to the current working directory.
model = GPT2LMHeadModel.from_pretrained("./model")
tokenizer = GPT2Tokenizer.from_pretrained("./tokenizer")
# GPT-2's tokenizer has no padding token by default; reuse EOS so padding is
# possible. NOTE(review): presumably BeatGenerator pads its inputs — confirm.
tokenizer.pad_token = tokenizer.eos_token
# Single shared generator instance used by the request handler below.
beat_generator = BeatGenerator(model=model, tokenizer=tokenizer)
# def on_submit(*grid_rows) -> [str]:
# step_data_container = []
# for grid_row_id in range(INSTRUMENT_COUNT):
# grid_row_as_ints = list(map(lambda x: int(x) - 1, grid_rows[grid_row_id]))
# step_data_container.append(grid_row_as_ints)
# temperature: float = grid_rows[9]
# tempo: int = grid_rows[10]
# now = datetime.now()
# date_string = now.strftime("%Y-%m-%d_%H-%M")
# genre, midi_data = beat_generator.generate_beat(user_prompt=step_data_container, temperature=temperature, tempo=tempo)
# return ["""<div><h3>Genre: {0}</h3></div><br/><div><a href="data:audio/midi;base64,{1}" download="beat-{0}-{2}.mid">Download beat</a></div>""".format(genre, midi_data, date_string)]
# checkbox_rows = [
# ["{:02d}".format(col + 1) for col in range(STEP_COUNT)] for _ in range(INSTRUMENT_COUNT)
# ]
# inputs = [
# gr.inputs.CheckboxGroup(checkbox_rows[0], label=f"Kick"),
# gr.inputs.CheckboxGroup(checkbox_rows[1], label=f"Snare"),
# gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"Clap"),
# gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"Hat"),
# gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"L tom"),
# gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"Open hat"),
# gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"M tom"),
# gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"Crash cymbal"),
# gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"Ride cymbal"),
# gr.inputs.Slider(minimum=0.1, maximum=1.0, step=0.1, default=0.7, label="Temperature"),
# gr.inputs.Slider(minimum=60, maximum=200, step=1, default=120, label="Tempo")
# ]
# iface = gr.Interface(
# fn=on_submit,
# inputs=inputs,
# outputs=["html"],
# title="Simple (MIDI) Beat Generator",
# description="A simple beat generator that creates an 8-bar MIDI beats on every run, based on a 32-step (2 bars) prompt in the form of a step sequencer. The generator uses a small fine-tuned GPT-2 model to recognise the genre (currently only Trap and Deep House) and generate the beat."
# )
# iface.launch()
# Custom textbox input component that decodes its value from JSON.
# NOTE(review): this class is not used by the Interface at the bottom of the
# file, which takes a plain "text" input. `gr.inputs` is deprecated in
# Gradio 3.x and was removed in Gradio 4.x; on 4.x this definition would fail
# at import time. TODO confirm the pinned Gradio version, or delete the class.
class JSONInput(gr.inputs.Textbox):
    """Textbox input whose submitted value is parsed as JSON."""

    def preprocess(self, x):
        """Return *x* decoded as JSON, or ``None`` if it cannot be decoded.

        Args:
            x: The raw textbox value, expected to be a JSON string.

        Returns:
            The decoded Python object. Returns ``None`` when *x* is not valid
            JSON, or is not a str/bytes value at all (e.g. ``None``).
        """
        try:
            return json.loads(x)
        except (json.JSONDecodeError, TypeError):
            # TypeError: json.loads() rejects non-str/bytes input such as None;
            # treat it like any other undecodable value instead of crashing.
            return None
# Define your processing function
def on_did_receive_input(text: [str]):
try:
# Parse the JSON string into a Python object
# return text
input_json_value = json.loads(text.replace("'", '"'))
try:
tempo: int = input_json_value['tempo']
temperature: float = input_json_value['temperature']
user_prompt: [[int]] = input_json_value['music_data']
genre, events = beat_generator.generate_beat(
user_prompt=user_prompt, temperature=temperature, tempo=tempo
)
return json.dumps(
{
"genre": genre,
"events": events
}
)
except KeyError:
return "Error! Message was not found in JSON input"
except JSONDecodeError as e:
return "Error! Invalid JSON input: {0}".format(e)
# Wrap the JSON handler in a minimal text-in / text-out gr.Interface (not a
# gr.Blocks layout) and start serving it as soon as this script runs.
iface = gr.Interface(fn=on_did_receive_input, inputs="text", outputs="text")
iface.launch()