"""Gradio front end for a fine-tuned GPT-2 MIDI beat generator.

Loads a GPT-2 model from ``./model`` and its tokenizer from ``./tokenizer``
at import time and wraps them in a module-level ``BeatGenerator`` instance.
"""
import json
from datetime import datetime  # noqa: F401  (only used by the retired checkbox UI)
from json.decoder import JSONDecodeError

import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from beatgenerator import BeatGenerator

# Step-sequencer grid size: 32 steps by 9 drum voices.
# NOTE(review): no live code in this module reads these; kept in case other
# modules import them — confirm before deleting.
STEP_COUNT = 32
INSTRUMENT_COUNT = 9

model = GPT2LMHeadModel.from_pretrained("./model")
tokenizer = GPT2Tokenizer.from_pretrained("./tokenizer")
# GPT-2's tokenizer defines no padding token; reuse EOS so padding is possible.
tokenizer.pad_token = tokenizer.eos_token
beat_generator = BeatGenerator(model=model, tokenizer=tokenizer)

# The former checkbox-based step-sequencer UI (one gr.inputs.CheckboxGroup per
# instrument plus temperature/tempo sliders, rendered as HTML with a download
# link) was removed as dead commented-out code; recover it from version control
# if it is needed again.
# Custom input component that hands the wrapped function a decoded JSON value.
class JSONInput(gr.inputs.Textbox):
    """Textbox input that decodes its value as JSON before it reaches ``fn``.

    NOTE(review): this class is not used by the interface built below, which
    takes a plain ``"text"`` input. ``gr.inputs`` is Gradio's legacy input
    namespace (deprecated in 3.x, removed in 4.x) — confirm the installed
    Gradio version before wiring this in.
    """

    def preprocess(self, x):
        """Return ``x`` decoded as JSON, or ``None`` if it cannot be decoded."""
        try:
            return json.loads(x)
        except (json.JSONDecodeError, TypeError):
            # TypeError covers a missing/non-string value (e.g. ``None``),
            # which json.loads rejects before attempting to parse.
            return None


def _parse_request(text):
    """Decode ``text`` as JSON, tolerating single-quoted (Python-style) dicts.

    Raises:
        JSONDecodeError: if neither the strict parse nor the quote-swapped
            fallback parse succeeds (the fallback's error is the one raised).
    """
    try:
        return json.loads(text)
    except JSONDecodeError:
        # Legacy leniency: accept inputs like "{'tempo': 120}" by swapping
        # quote characters. Only attempted after strict parsing fails, so valid
        # JSON whose strings contain apostrophes is no longer corrupted.
        return json.loads(text.replace("'", '"'))


def on_did_receive_input(text: str) -> str:
    """Generate a beat from a JSON request and return the result as JSON.

    The request must be a JSON object with the keys ``tempo``, ``temperature``
    and ``music_data``; they are forwarded to
    ``beat_generator.generate_beat`` as ``tempo``, ``temperature`` and
    ``user_prompt`` respectively.

    Args:
        text: Raw request string from the Gradio text input.

    Returns:
        ``json.dumps({"genre": ..., "events": ...})`` built from the
        generator's result on success, or a human-readable string starting
        with ``"Error!"`` when the request cannot be parsed or lacks a key.
        Exceptions raised by the generator itself propagate unchanged so they
        are not misreported as bad input.
    """
    try:
        # ``or ""`` turns a missing value into an ordinary decode error.
        request = _parse_request(text or "")
    except JSONDecodeError as e:
        return "Error! Invalid JSON input: {0}".format(e)

    # Subscripting a list/number/string below would raise TypeError.
    if not isinstance(request, dict):
        return "Error! Invalid JSON input: expected a JSON object"

    try:
        tempo: int = request['tempo']
        temperature: float = request['temperature']
        # presumably a list of integer lists (one per instrument); the shape is
        # not validated here — confirm against BeatGenerator.generate_beat.
        user_prompt: list[list[int]] = request['music_data']
    except KeyError:
        return "Error! Message was not found in JSON input"

    genre, events = beat_generator.generate_beat(
        user_prompt=user_prompt,
        temperature=temperature,
        tempo=tempo,
    )
    return json.dumps({"genre": genre, "events": events})


# Build a text-in / text-out Gradio Interface around the handler and serve it.
iface = gr.Interface(
    fn=on_did_receive_input,
    inputs="text",
    outputs="text",
)
iface.launch()