Spaces:
Build error
Build error
#%%
# Standard library
import os
import re

# Third-party (duplicate `import numpy as np` removed; imports grouped per PEP 8)
import gradio as gr
import numpy as np
import openai
import pretty_midi

# Default API key from the environment; a per-user key entered in the UI
# can override this later (see on_token_change).
openai.api_key = os.environ.get("OPENAI_API_KEY")
# Few-shot sample shown to ChatGPT: first line is the time resolution
# ("8th" = 8 steps per bar), then a Markdown grid whose first column is an
# instrument label (drum abbreviation or pitched note name) and whose cells
# hold 'x' for a hit on that step.
markdown_table_sample = """8th
| | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|----|---|---|---|---|---|---|---|---|
| C5 | x | x | | x | | | | |
| D5 | | | | | | | | |
| E5 | | | | | | | | |
| F5 | | x | | | x | | | x |
| G5 | | | x | x | | | x | |
| A5 | | | | | | | | |
| B5 | | | | | | | | |
| C6 | | | | | | | | |
| BD | x | | x | | | | x | |
| SD | | | | x | | | | x |
| CH | x | | x | | x | | x | |
| OH | | | | x | | | x | |
| LT | | | | | | x | | |
| MT | | x | | | x | | | |
| HT | x | | | x | | | | |
"""
# Second sample at 16th-note resolution (drums only).
# NOTE(review): defined but not referenced anywhere in this file --
# presumably kept for experimenting with the commented-out prompt in
# get_answer; confirm before removing.
markdown_table_sample2 = """16th
| | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10| 11| 12| 13| 14| 15| 16|
|----|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| BD | x | | x | | | | x | | x | | x | | x | | x | |
| SD | | | | x | | | | x | | | x | | | | x | |
| CH | x | | x | | x | | x | | x | | x | | x | | x | |
| OH | | | | x | | | x | | | | | x | | | x | |
| LT | | | | | | x | | | | | | | | x | | |
| MT | | x | | | x | | | | | x | | | x | | | |
| HT | x | | | x | | | | | x | | | x | | | | |
"""
# Drum-label abbreviations (as used in the Markdown rhythm tables) mapped to
# MIDI percussion note numbers for the drum-channel instrument.
MIDI_NOTENUM = {
    "BD": 36,  # bass/kick drum
    "SD": 38,  # snare drum
    "CH": 42,  # closed hi-hat
    "HH": 44,  # pedal hi-hat
    "OH": 46,  # open hi-hat
    "LT": 48,  # low tom  -- NOTE(review): same pitch as MT; GM low tom is 45 -- confirm intent
    "MT": 48,  # mid tom
    "HT": 50,  # high tom
    "CP": 50,  # NOTE(review): shares 50 with HT; GM hand clap is 39 -- confirm intent
    "CB": 56,  # cowbell
}
# Audio sample rate in Hz; must match the rate pretty_midi.fluidsynth() renders at.
SR = 44100
# Number of free generations allowed before the user must supply an API key.
MAX_QUERY = 5
def convert_table_to_audio(markdown_table, resolution=8, bpm=90.0):
    """Render a Markdown rhythm table to an audio clip.

    Parameters
    ----------
    markdown_table : str
        Table whose first column is an instrument label (a MIDI_NOTENUM key
        or a pitched note name such as 'C5'/'C#4') and whose remaining cells
        contain 'x' (accent), 'o' (soft) or blank per grid step.  The first
        two lines (header row and separator row) are skipped.
    resolution : int
        Grid steps per bar (8 = 8th notes, 16 = 16th notes).
    bpm : float
        Tempo in beats per minute.

    Returns
    -------
    numpy.ndarray
        Synthesized samples, trimmed so the clip loops cleanly.

    Raises
    ------
    ValueError
        If a label is neither a MIDI_NOTENUM key nor a valid note name
        (propagated from pretty_midi.note_name_to_number).
    """
    # Split table rows into cell lists; [1:-1] drops the empty strings
    # produced by the leading/trailing '|'.  (Debug print removed.)
    rhythm_pattern = [line.split('|')[1:-1]
                     for line in markdown_table.split('\n')[2:]]
    # Build a MIDI sequence on a single drum-channel instrument.
    pm = pretty_midi.PrettyMIDI(initial_tempo=bpm)  # midi object
    pm_inst = pretty_midi.Instrument(0, is_drum=True)  # midi instrument
    pm.instruments.append(pm_inst)
    note_length = (60. / bpm) * (4.0 / resolution)  # one grid step, seconds
    beat_num = resolution  # last visited column index, used to trim for looping
    for row in rhythm_pattern:
        if not row:
            continue  # blank line (e.g. trailing newline)
        inst = row[0].strip().upper()  # hoisted: label is constant per row
        for j in range(1, len(row)):
            beat_num = j  # for looping
            cell = row[j].strip()
            if cell == 'x':
                velocity = 120  # accented hit
            elif cell == 'o':
                velocity = 65   # soft hit
            else:
                continue        # empty cell: no note on this step
            if inst in MIDI_NOTENUM:
                midinote = MIDI_NOTENUM[inst]
            else:
                # Fall back to pitched note names like 'C5' or 'C#4'.
                midinote = pretty_midi.note_name_to_number(inst)
            # Tiny start offset keeps note-ons from colliding with the
            # previous step's note-off at the same timestamp.
            note = pretty_midi.Note(velocity=velocity, pitch=midinote,
                                    start=note_length * (j - 1) + 0.0001,
                                    end=note_length * j)
            pm_inst.notes.append(note)
    # Synthesize, then cut off the reverb tail so the clip loops cleanly.
    audio_data = pm.fluidsynth()
    audio_data = audio_data[:int(SR * note_length * beat_num)]
    return audio_data
def get_answer(question):
    """Send *question* to ChatGPT and return the raw reply text.

    The system prompt plus the one-shot assistant example
    (markdown_table_sample) prime the model to reply with a time-resolution
    line followed by a Markdown rhythm table.

    NOTE(review): uses the pre-1.0 `openai.ChatCompletion` API; auth or
    rate-limit errors propagate to the caller unhandled -- confirm the
    installed openai version matches.
    """
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a music generator. "},
            {"role": "user", "content": "Please generate a music in a Markdown table. Time resolution is the 8th note. You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. And you also use MIDI note number, in format '(note)(accidental)(octave number)' (e.g. C5, D5, E5, C#4). use 'x' for an accented instrument, 'o' for a weak instrument. You need to write the time resolution first."},
            {"role": "assistant", "content": markdown_table_sample},
            # {"role": "user", "content": "Please generate a music pattern. The resolution is the fourth note. You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. use 'x' for an accented beat, 'o' for a weak beat. You need to write the time resolution first."},
            # {"role": "assistant", "content": markdown_table_sample},
            {"role": "user", "content": question}
        ]
    )
    # Return only the assistant message body of the first choice.
    return response["choices"][0]["message"]["content"]
def generate_rhythm(query, state):
    """Gradio callback: ask ChatGPT for a rhythm table and render it to audio.

    Parameters
    ----------
    query : str
        The user's free-text music request.
    state : dict
        Session state with keys 'gen_count' (int) and 'user_token' (str).
        Mutated in place (increments 'gen_count').

    Returns
    -------
    list
        [(SR, samples), chatgpt_text] on success, or [None, error_message]
        when the free-query limit is exhausted and no user key is set.
    """
    # (Debug `print(state)` removed: state contains the user's API key,
    # which must not be written to logs.)
    if state["gen_count"] > MAX_QUERY and len(state["user_token"]) == 0:
        return [None, "You need to set your ChatGPT API Key to try more than %d times" % MAX_QUERY]
    state["gen_count"] = state["gen_count"] + 1
    # get response from ChatGPT
    text_output = get_answer(query)
    # The text before the first '|' should contain the time resolution.
    resolution_text = text_output.split('|')[0]
    try:
        resolution = int(re.findall(r'\d+', resolution_text)[0])
    except (IndexError, ValueError):
        # No number found in the preamble -- fall back to 8th notes.
        resolution = 8  # default
    # Extract the rhythm table: everything between the first and last '|'.
    table = "|" + "|".join(text_output.split('|')[1:-1]) + "|"
    audio_data = convert_table_to_audio(table, resolution)
    # Repeat the one-bar pattern 4 times (original comment said "x2").
    audio_data = np.tile(audio_data, 4)
    return [(SR, audio_data), text_output]
| # %% | |
def on_token_change(user_token, state):
    """Gradio callback: adopt the user-supplied OpenAI API key.

    Falls back to the OPENAI_API_KEY environment variable when the textbox
    is cleared (empty string is falsy).  Returns the updated session state.
    """
    # Security fix: the previous `print(user_token)` wrote the secret API
    # key to the server logs -- never log credentials.
    openai.api_key = user_token or os.environ.get("OPENAI_API_KEY")
    state["user_token"] = user_token
    return state
# --- Gradio UI; built and launched at import time ---
with gr.Blocks() as demo:
    # Per-session state: free-generation counter and the user's own API key.
    state = gr.State({"gen_count": 0, "user_token":""})
    # Header row: title and instrument legend.
    with gr.Row():
        with gr.Column():
            # gr.Markdown("Ask ChatGPT to generate rhythm patterns")
            gr.Markdown("***Hey TR-ChatGPT, give me a music!***")
            gr.Markdown("You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. \
And you also use MIDI note number, in format '(note)(accidental)(octave number)' (e.g. C5, D5, E5, C#4). use 'x' for an accented instrument, 'o' for a weak instrument.", elem_id="label")
    # Main row: prompt input on the left, audio + raw ChatGPT text on the right.
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(placeholder="Give me a Nice music!")
            btn = gr.Button("Generate")
        with gr.Column():
            out_audio = gr.Audio()
            out_text = gr.Textbox(placeholder="ChatGPT output")
    # Footer row: optional user API key (masked input).
    with gr.Row():
        with gr.Column():
            gr.Markdown("Enter your own OpenAI API Key to try out more than 5 times. You can get it [here](https://platform.openai.com/account/api-keys).")
            user_token = gr.Textbox(placeholder="OpenAI API Key", type="password", show_label=False)
    # Wire callbacks to the widgets defined above.
    btn.click(fn=generate_rhythm, inputs=[inp, state], outputs=[out_audio, out_text])
    user_token.change(on_token_change, inputs=[user_token, state], outputs=[state])
demo.launch()