Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from beatgenerator import BeatGenerator | |
| from datetime import datetime | |
| from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
| import json | |
| from json.decoder import JSONDecodeError | |
# Prompt-grid dimensions: steps per prompt row and number of instrument rows.
STEP_COUNT, INSTRUMENT_COUNT = 32, 9


def _load_beat_generator(model_dir="./model", tokenizer_dir="./tokenizer"):
    """Load the GPT-2 model and tokenizer from local directories and wrap them.

    Returns a ``(model, tokenizer, beat_generator)`` tuple so the module can
    expose all three objects under their original global names.
    """
    gpt2_model = GPT2LMHeadModel.from_pretrained(model_dir)
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_dir)
    # Padding reuses the EOS token (presumably because this tokenizer ships
    # without a dedicated pad token — confirm against the training setup).
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
    generator = BeatGenerator(model=gpt2_model, tokenizer=gpt2_tokenizer)
    return gpt2_model, gpt2_tokenizer, generator


model, tokenizer, beat_generator = _load_beat_generator()
| # def on_submit(*grid_rows) -> [str]: | |
| # step_data_container = [] | |
| # for grid_row_id in range(INSTRUMENT_COUNT): | |
| # grid_row_as_ints = list(map(lambda x: int(x) - 1, grid_rows[grid_row_id])) | |
| # step_data_container.append(grid_row_as_ints) | |
| # temperature: float = grid_rows[9] | |
| # tempo: int = grid_rows[10] | |
| # now = datetime.now() | |
| # date_string = now.strftime("%Y-%m-%d_%H-%M") | |
| # genre, midi_data = beat_generator.generate_beat(user_prompt=step_data_container, temperature=temperature, tempo=tempo) | |
| # return ["""<div><h3>Genre: {0}</h3></div><br/><div><a href="data:audio/midi;base64,{1}" download="beat-{0}-{2}.mid">Download beat</a></div>""".format(genre, midi_data, date_string)] | |
| # checkbox_rows = [ | |
| # ["{:02d}".format(col + 1) for col in range(STEP_COUNT)] for _ in range(INSTRUMENT_COUNT) | |
| # ] | |
| # inputs = [ | |
| # gr.inputs.CheckboxGroup(checkbox_rows[0], label=f"Kick"), | |
| # gr.inputs.CheckboxGroup(checkbox_rows[1], label=f"Snare"), | |
| # gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"Clap"), | |
| # gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"Hat"), | |
| # gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"L tom"), | |
| # gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"Open hat"), | |
| # gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"M tom"), | |
| # gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"Crash cymbal"), | |
| # gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"Ride cymbal"), | |
| # gr.inputs.Slider(minimum=0.1, maximum=1.0, step=0.1, default=0.7, label="Temperature"), | |
| # gr.inputs.Slider(minimum=60, maximum=200, step=1, default=120, label="Tempo") | |
| # ] | |
| # iface = gr.Interface( | |
| # fn=on_submit, | |
| # inputs=inputs, | |
| # outputs=["html"], | |
| # title="Simple (MIDI) Beat Generator", | |
| # description="A simple beat generator that creates an 8-bar MIDI beats on every run, based on a 32-step (2 bars) prompt in the form of a step sequencer. The generator uses a small fine-tuned GPT-2 model to recognise the genre (currently only Trap and Deep House) and generate the beat." | |
| # ) | |
| # iface.launch() | |
# Custom textbox component whose input is parsed as JSON.
# Subclasses gr.Textbox: the old gr.inputs namespace was removed in Gradio 4,
# and referencing it here made the whole module fail at import time.
class JSONInput(gr.Textbox):
    """Textbox component that hands the parsed JSON value to the handler.

    NOTE(review): this component is not used by the Interface in this file,
    which declares ``inputs="text"``.
    """

    def preprocess(self, x):
        """Parse the submitted text as JSON.

        Args:
            x: The raw textbox payload (normally a string; may be None).

        Returns:
            The decoded Python object, or None when the payload is missing
            or is not valid JSON.
        """
        try:
            return json.loads(x)
        except (json.JSONDecodeError, TypeError):
            # TypeError covers a None / non-string payload, which json.loads
            # rejects before attempting to decode.
            return None
# Request handler: JSON text in, JSON text (or an "Error! ..." string) out.
def on_did_receive_input(text: str) -> str:
    """Generate a beat from a JSON-encoded request.

    The request must be a JSON object with the keys ``tempo``,
    ``temperature`` and ``music_data``. These are forwarded to
    ``beat_generator.generate_beat`` as ``tempo``, ``temperature`` and
    ``user_prompt``. Strict JSON is parsed first. If that fails, single
    quotes are swapped for double quotes and parsing is retried, so
    Python-dict-style input such as ``{'tempo': 120}`` is still accepted.

    Args:
        text: The raw request text from the Gradio textbox.

    Returns:
        A JSON string ``{"genre": ..., "events": ...}`` on success, otherwise
        a human-readable string starting with ``"Error!"``.
    """
    if text is None:
        return "Error! No input received"

    try:
        input_json_value = json.loads(text)
    except JSONDecodeError:
        # Fallback for single-quoted pseudo-JSON. It is only tried after a
        # strict parse fails, so valid JSON containing apostrophes inside
        # string values is no longer corrupted by the quote swap.
        try:
            input_json_value = json.loads(text.replace("'", '"'))
        except JSONDecodeError as e:
            return "Error! Invalid JSON input: {0}".format(e)

    if not isinstance(input_json_value, dict):
        return "Error! JSON input must be an object"

    # Only key extraction sits inside this try. A KeyError raised by the
    # generator itself must not be misreported as a missing input key.
    try:
        tempo: int = input_json_value["tempo"]
        temperature: float = input_json_value["temperature"]
        # NOTE(review): presumably one list of step values per instrument;
        # the exact encoding is defined by BeatGenerator — confirm there.
        user_prompt = input_json_value["music_data"]
    except KeyError as e:
        return "Error! Key {0} was not found in JSON input".format(e)

    genre, events = beat_generator.generate_beat(
        user_prompt=user_prompt, temperature=temperature, tempo=tempo
    )
    return json.dumps({"genre": genre, "events": events})
# Wrap on_did_receive_input in a plain text-in / text-out gr.Interface
# (not gr.Blocks) and launch it right away at module level.
iface = gr.Interface(on_did_receive_input, "text", "text")
iface.launch()