Spaces:
Runtime error
Runtime error
Achillefs Sourlas
committed on
Commit
·
31adff7
1
Parent(s):
2ac38d9
Initial commit
Browse files- .gitignore +2 -0
- app.py +100 -0
- beatgenerator.py +81 -0
- customtokenencoderdecoder.py +203 -0
- model/config.json +39 -0
- model/generation_config.json +6 -0
- model/pytorch_model.bin +3 -0
- model/training_args.bin +3 -0
- requirements.txt +4 -0
- tokenizer/added_tokens.json +66 -0
- tokenizer/merges.txt +0 -0
- tokenizer/special_tokens_map.json +24 -0
- tokenizer/tokenizer_config.json +33 -0
- tokenizer/vocab.json +0 -0
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
venv/
|
| 2 |
+
__pycache__/
|
app.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
from beatgenerator import BeatGenerator
from datetime import datetime
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import json

# Grid dimensions of the step-sequencer prompt: 32 steps (2 bars of
# 16th notes) for each of the 9 drum instruments.
STEP_COUNT = 32
INSTRUMENT_COUNT = 9

# Load the fine-tuned GPT-2 model and its tokenizer from local directories.
# NOTE(review): relative paths — assumes the app is launched from the
# repository root; confirm for the deployment environment.
model = GPT2LMHeadModel.from_pretrained("./model")
tokenizer = GPT2Tokenizer.from_pretrained("./tokenizer")
# GPT-2 ships without a pad token; reuse EOS so padding during generation works.
tokenizer.pad_token = tokenizer.eos_token
beat_generator = BeatGenerator(model=model, tokenizer=tokenizer)
|
| 14 |
+
|
| 15 |
+
# def on_submit(*grid_rows) -> [str]:
|
| 16 |
+
# step_data_container = []
|
| 17 |
+
|
| 18 |
+
# for grid_row_id in range(INSTRUMENT_COUNT):
|
| 19 |
+
# grid_row_as_ints = list(map(lambda x: int(x) - 1, grid_rows[grid_row_id]))
|
| 20 |
+
# step_data_container.append(grid_row_as_ints)
|
| 21 |
+
|
| 22 |
+
# temperature: float = grid_rows[9]
|
| 23 |
+
# tempo: int = grid_rows[10]
|
| 24 |
+
# now = datetime.now()
|
| 25 |
+
# date_string = now.strftime("%Y-%m-%d_%H-%M")
|
| 26 |
+
|
| 27 |
+
# genre, midi_data = beat_generator.generate_beat(user_prompt=step_data_container, temperature=temperature, tempo=tempo)
|
| 28 |
+
|
| 29 |
+
# return ["""<div><h3>Genre: {0}</h3></div><br/><div><a href="data:audio/midi;base64,{1}" download="beat-{0}-{2}.mid">Download beat</a></div>""".format(genre, midi_data, date_string)]
|
| 30 |
+
|
| 31 |
+
# checkbox_rows = [
|
| 32 |
+
# ["{:02d}".format(col + 1) for col in range(STEP_COUNT)] for _ in range(INSTRUMENT_COUNT)
|
| 33 |
+
# ]
|
| 34 |
+
|
| 35 |
+
# inputs = [
|
| 36 |
+
# gr.inputs.CheckboxGroup(checkbox_rows[0], label=f"Kick"),
|
| 37 |
+
# gr.inputs.CheckboxGroup(checkbox_rows[1], label=f"Snare"),
|
| 38 |
+
# gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"Clap"),
|
| 39 |
+
# gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"Hat"),
|
| 40 |
+
# gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"L tom"),
|
| 41 |
+
# gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"Open hat"),
|
| 42 |
+
# gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"M tom"),
|
| 43 |
+
# gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"Crash cymbal"),
|
| 44 |
+
# gr.inputs.CheckboxGroup(checkbox_rows[2], label=f"Ride cymbal"),
|
| 45 |
+
# gr.inputs.Slider(minimum=0.1, maximum=1.0, step=0.1, default=0.7, label="Temperature"),
|
| 46 |
+
# gr.inputs.Slider(minimum=60, maximum=200, step=1, default=120, label="Tempo")
|
| 47 |
+
# ]
|
| 48 |
+
|
| 49 |
+
# iface = gr.Interface(
|
| 50 |
+
# fn=on_submit,
|
| 51 |
+
# inputs=inputs,
|
| 52 |
+
# outputs=["html"],
|
| 53 |
+
# title="Simple (MIDI) Beat Generator",
|
| 54 |
+
# description="A simple beat generator that creates an 8-bar MIDI beats on every run, based on a 32-step (2 bars) prompt in the form of a step sequencer. The generator uses a small fine-tuned GPT-2 model to recognise the genre (currently only Trap and Deep House) and generate the beat."
|
| 55 |
+
# )
|
| 56 |
+
|
| 57 |
+
# iface.launch()
|
| 58 |
+
|
| 59 |
+
# Create a custom block for JSON input
class JSONInput(gr.inputs.Textbox):
    """A Textbox whose raw string value is parsed as JSON before reaching
    the wrapped function.

    NOTE(review): appears unused — the interface below passes inputs="text"
    and parses JSON inside on_did_receive_input instead; confirm before
    removing.
    """

    def preprocess(self, x):
        """Parse the textbox value *x* into a Python object.

        Returns None when *x* is not valid JSON, or is not a string at all
        (Gradio passes None for an empty/absent textbox).
        """
        try:
            return json.loads(x)
        except (TypeError, json.JSONDecodeError):
            # TypeError: x is None / not str-like (previously uncaught and
            # crashed the request); JSONDecodeError: malformed JSON text.
            return None
|
| 68 |
+
|
| 69 |
+
# Define your processing function
def on_did_receive_input(text: str) -> str:
    """Validate and echo back a JSON beat-generation request.

    Expects *text* to be a JSON object with keys "tempo", "temperature"
    and "music_data". Returns the normalised JSON string on success, or a
    human-readable error message on malformed input.

    NOTE(review): this currently only round-trips the payload — the actual
    beat_generator call does not appear to be wired up yet; confirm.
    """
    try:
        input_json_value = json.loads(text)
    except (TypeError, json.JSONDecodeError):
        # TypeError covers a None / non-string payload from Gradio, which
        # previously escaped the handler and crashed the request.
        return "Error! Invalid JSON input"

    try:
        tempo: int = input_json_value["tempo"]
        temperature: float = input_json_value["temperature"]
        data: list = input_json_value["music_data"]
    except (KeyError, TypeError):
        # KeyError: a required key is missing; TypeError: the top-level
        # JSON value was not an object (e.g. a bare list or number), which
        # previously raised uncaught.
        return "Error! Message was not found in JSON input"

    # Rebuild the payload explicitly so only the known keys are echoed back.
    # (Renamed from `dict`, which shadowed the builtin.)
    payload = {
        "tempo": tempo,
        "temperature": temperature,
        "music_data": data,
    }
    return json.dumps(payload)
|
| 92 |
+
|
| 93 |
+
# Create the Gradio interface: one free-text box in, the validated JSON
# payload (or an error message) out.
iface = gr.Interface(
    fn=on_did_receive_input,
    inputs="text",
    outputs="text"
)

# Start the app server (blocking call).
iface.launch()
|
beatgenerator.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from midiutil import MIDIFile
|
| 2 |
+
import base64
|
| 3 |
+
from io import BytesIO
|
| 4 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
| 5 |
+
from customtokenencoderdecoder import CustomTokenEncoderDecoder
|
| 6 |
+
|
| 7 |
+
class BeatGenerator:
    """Turns a user step-sequencer prompt into a full generated MIDI beat.

    Delegates token-level generation to CustomTokenEncoderDecoder, then
    renders the resulting (step, pitch) events to an in-memory MIDI file.
    """

    # Duration of one sequencer step in beats (a 16th note in 4/4).
    STEP_SIZE = 0.25
    # Steps per section: 2 bars of 16th notes.
    STEPS_PER_SEQUENCE = 32

    def __init__(self, model: GPT2LMHeadModel, tokenizer: GPT2Tokenizer):
        self.__model = model
        self.__tokenizer = tokenizer
        # Song structure: four 32-step sections, generated one after another.
        self.__sections = ["a", "b", "c", "d"]

    def generate_beat(self, user_prompt: [[int]], temperature: float, tempo: float) -> [str, str]:
        """Generate a beat from *user_prompt*.

        user_prompt: one list of active step indices per instrument, in the
        same order as *pitches* below (9 instruments).
        temperature: sampling temperature for the generation passes.
        tempo: beats per minute written to the MIDI file.
        Returns (genre, base64-encoded MIDI file contents).
        """
        # pitches = [36, 38, 42]
        # General-MIDI drum notes — presumably kick, snare, clap, closed hat,
        # low tom, open hat, mid tom, crash, ride, matching the UI row order;
        # TODO confirm against the frontend.
        pitches = [36, 38, 39, 42, 45, 46, 47, 49, 51]
        assert len(user_prompt) == len(pitches), "User prompt length must be equal to the number of pitches"

        # Flatten the per-instrument step lists into (step, pitch) events.
        user_events: [[int, int]] = []
        for pitch_id, pitch in enumerate(pitches):
            for step in user_prompt[pitch_id]:
                user_events.append((step, pitch))

        custom_token_encoder_decoder = CustomTokenEncoderDecoder(
            events=user_events,
            sections=self.__sections,
            steps_per_section=self.STEPS_PER_SEQUENCE,
            model=self.__model,
            tokenizer=self.__tokenizer,
        )

        result = custom_token_encoder_decoder.generate_events(temperature=temperature)

        genre = result["genre"]
        events = result["events"]

        midi_buffer = self.__make_midi_buffer(
            data_container=events,
            tempo=tempo,
            verbose=False
        )
        midi_base64 = base64.b64encode(midi_buffer.read()).decode("utf-8")

        return genre, midi_base64

    def __make_midi_buffer(self, data_container: [(int, int)], tempo: int, verbose: bool = False) -> BytesIO:
        """Render (step, pitch) events into a single-track MIDI file.

        Events whose step or pitch falls outside [0, 128) are silently
        skipped. Returns a BytesIO positioned at the start of the data.
        """
        out_midi_file = MIDIFile(1)
        out_midi_file.addTempo(0, 0, tempo)

        for data in data_container:
            step = data[0]
            pitch = data[1]
            velocity = 100

            if verbose is True:
                # Fixed: previously referenced the undefined names
                # step_ranges / section_id and raised NameError whenever
                # verbose was enabled.
                print("Processing: {0}".format(data))

            if step >= 0 and step < 128 and pitch >= 0 and pitch < 128:
                start_time = float(step) * self.STEP_SIZE
                volume = int(velocity)

                out_midi_file.addNote(
                    track=0,
                    channel=9,  # channel 10 (0-indexed 9): the GM drum channel
                    pitch=pitch,
                    time=start_time,
                    duration=self.STEP_SIZE,
                    volume=volume
                )

        buffer = BytesIO()
        out_midi_file.writeFile(buffer)
        buffer.seek(0)

        # NOTE(review): debug artifact — also dumps the beat to out.mid in
        # the working directory on every call; consider removing for the
        # deployed Space.
        with open("out.mid", "wb") as output_file:
            out_midi_file.writeFile(output_file)

        return buffer
|
customtokenencoderdecoder.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
| 2 |
+
|
| 3 |
+
class CustomTokenEncoderDecoder:
    """Translates between (step, pitch) drum events and the custom token
    vocabulary the fine-tuned GPT-2 model was trained on, and runs the
    model to (a) classify the genre/section of the user prompt and
    (b) generate events for the remaining sections.

    Token formats (see tokenizer/added_tokens.json): "step:<n>",
    "pitch:<n>", "section:<a|b|c|d>", "genre:<name>".
    """

    # Appended to a prompt to ask the model to emit genre + section tokens.
    CUSTOM_CLASSIFICATION_TOKEN = "which_genre_section"

    def __init__(self, events: [[int, int]], sections: [str], steps_per_section: int, model: GPT2LMHeadModel, tokenizer: GPT2Tokenizer):
        self.__model = model
        self.__tokenizer = tokenizer
        self.__events = events
        self.__steps_per_section = steps_per_section
        self.__sections = sections
        # Pre-encode the user events once; reused for every prompt below.
        self.__events_tokens = self.events_to_tokens(events)

    # --- token predicates and converters ---------------------------------

    def is_step_token(self, token: str) -> bool:
        return token.startswith("step:")

    def is_pitch_token(self, token: str) -> bool:
        return token.startswith("pitch:")

    def is_genre_token(self, token: str) -> bool:
        return token.startswith("genre:")

    def is_section_token(self, token: str) -> bool:
        return token.startswith("section:")

    def token_to_pitch(self, token: str) -> int:
        return int(token.split(":")[1])

    def token_to_step(self, token: str) -> int:
        return int(token.split(":")[1])

    def token_to_section(self, token: str) -> str:
        return token.split(":")[1]

    def token_to_genre(self, token: str) -> str:
        return token.split(":")[1]

    def pitch_to_token(self, pitch: int) -> str:
        return "pitch:{0}".format(pitch)

    def step_to_token(self, step: int) -> [str]:
        return "step:{0}".format(step)

    def section_to_token(self, section: str) -> [str]:
        return "section:{0}".format(section)

    # --- event <-> token encoding ----------------------------------------

    def events_to_tokens(self, events: [[int, int]]) -> [str]:
        """Encode events as a flat token list: each active step emits one
        "step:<n>" token followed by one "pitch:<n>" token per hit."""
        result: [str] = []

        for step_id in range(self.__steps_per_section):
            step_data = list(filter(lambda x: x[0] == step_id, events))

            if len(step_data) > 0:
                result.append(self.step_to_token(step_id))
                step_tokens = list(map(lambda x: self.pitch_to_token(x[1]), step_data))
                if len(step_tokens) > 0:
                    result.append(*step_tokens) if False else result.extend(step_tokens)

        return result

    def tokens_to_classification_prompt(self, tokens: [str]) -> str:
        """Build the prompt that asks the model for genre/section tokens."""
        return " ".join(tokens + [self.CUSTOM_CLASSIFICATION_TOKEN])

    def tokens_to_section_prompt(self, tokens: [str], section: str, prompted_section: str) -> str:
        """Build a prompt labelling *tokens* as *section* and asking the
        model to continue with *prompted_section*."""
        return " ".join([self.section_to_token(section)] + tokens + [self.section_to_token(prompted_section)])

    def tokens_to_genre_section(self, tokens: [str]) -> dict:
        """Extract the last genre and section tokens seen in *tokens*.

        NOTE(review): duplicate of tokens_to_genre_and_section_information —
        consider consolidating.
        """
        genre: str = ""
        section: str = ""

        for token in tokens:
            if self.is_genre_token(token):
                genre = self.token_to_genre(token)
            elif self.is_section_token(token):
                section = self.token_to_section(token)

        return { "genre": genre, "section": section }

    def section_to_step_offset(self, section: str) -> int:
        """Map section name a/b/c/d to its absolute step offset in the song."""
        if section == "a":
            return 0
        elif section == "b":
            return self.__steps_per_section
        elif section == "c":
            return 2 * self.__steps_per_section
        elif section == "d":
            return 3 * self.__steps_per_section
        else:
            raise Exception("Invalid section: {0}".format(section))

    def tokens_to_section_events(self, tokens: [str], section: str, step_offset: int = None) -> [[int, int]]:
        """Decode the events that follow the *section* marker in *tokens*.

        step_offset, when given, overrides the section's natural offset
        (used by the inference fallback below). Raises when the section
        marker is absent.
        """
        for (token_id, token) in enumerate(tokens):
            if self.is_section_token(token):
                if self.token_to_section(token) == section:
                    offset: int = self.section_to_step_offset(section)
                    if step_offset is not None:
                        offset = step_offset
                    return self.tokens_to_events(tokens=tokens[token_id:], step_offset=offset)

        raise Exception("Section {0} not found in tokens".format(section))

    def tokens_to_events(self, tokens: [str], step_offset: int) -> [[int, int]]:
        """Decode a token stream into (step + step_offset, pitch) tuples.

        Pitch tokens only count while they directly follow a step token.
        """
        result: [[int, int]] = []

        for (token_id, token) in enumerate(tokens):
            if self.is_step_token(token):
                step = self.token_to_step(token) + step_offset
                next_token_id = token_id + 1

                while next_token_id < len(tokens) and self.is_pitch_token(tokens[next_token_id]):
                    pitch = self.token_to_pitch(tokens[next_token_id])
                    result.append((step, pitch))
                    next_token_id += 1

        return result

    def convert_events_to_section_events(self, events: [[int, int]], section: str) -> [[int, int]]:
        """Shift *events* by the step offset of *section*."""
        # Fixed: previously called the nonexistent self.step_offset_for_section,
        # raising AttributeError on every call.
        offset = self.section_to_step_offset(section)
        return list(map(lambda x: (x[0] + offset, x[1]), events))

    # --- inference --------------------------------------------------------

    def generate_events(self, temperature: float) -> dict:
        """Classify the user prompt, then generate the three remaining
        sections. Returns {"events": [(step, pitch), ...], "genre": str}."""
        genre_section_data = self.make_classification_inference(temperature=temperature)
        genre = genre_section_data["genre"]
        section = genre_section_data["section"]
        print("Classification results")
        print("======================")
        print("Found genre: {0}".format(genre))
        print("Found section: {0}".format(section))
        print("======================")

        all_events: [[int, int]] = []

        # The user prompt itself becomes the classified section's events.
        all_events += list(map(lambda x: (x[0] + self.section_to_step_offset(section=section), x[1]) ,self.__events))

        if section not in self.__sections:
            raise Exception("Section {0} not found in sections".format(section))

        other_sections = list(filter(lambda x: x != section, self.__sections))
        for other_section in other_sections:
            prompt = self.tokens_to_section_prompt(tokens=self.__events_tokens, section=section, prompted_section=other_section)
            events = self.make_section_events_inference(prompt=prompt, temperature=temperature, section=other_section, known_section=section)
            all_events += events

        return {
            "events": all_events,
            "genre": genre
        }

    def tokens_to_genre_and_section_information(self, tokens: [str]) -> dict:
        """Extract the last genre and section tokens seen in *tokens*."""
        genre: str = ""
        section: str = ""

        for token in tokens:
            if self.is_genre_token(token):
                genre = self.token_to_genre(token)
            elif self.is_section_token(token):
                section = self.token_to_section(token)

        return { "genre": genre, "section": section }

    def make_classification_inference(self, temperature: float) -> dict:
        """Ask the model which genre and section the user prompt belongs to."""
        genre_and_section_prompt = self.tokens_to_classification_prompt(self.__events_tokens)
        prompt = self.__tokenizer.encode(genre_and_section_prompt, add_special_tokens=True, return_tensors="pt")

        generated_section_genre_sequence = self.__model.generate(
            prompt,
            max_length=1024,
            do_sample=True,
            # NOTE(review): the *temperature* parameter is ignored here and a
            # fixed low temperature is used — presumably intentional to keep
            # classification stable; confirm.
            temperature=0.1,
            num_return_sequences=1,
        )

        section_genre_result = self.__tokenizer.decode(generated_section_genre_sequence[0], skip_special_tokens=True)
        assert len(section_genre_result) > 0, "Empty result"

        genre_section_data = self.tokens_to_genre_and_section_information(section_genre_result.split(" "))
        return genre_section_data

    def make_section_events_inference(self, prompt: str, section: str, temperature: float, known_section: str) -> [[int, int]]:
        """Generate the events of *section* by sampling a continuation of
        *prompt*. Falls back to re-reading the *known_section* span (shifted
        into place) when the model fails to emit the requested section."""
        tokenised_prompt = self.__tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt")
        assert len(tokenised_prompt[0]) <= 1024, "Prompt length exceeds maximum sequence length"

        generated_sequence = self.__model.generate(
            tokenised_prompt,
            max_length=1024,
            do_sample=True,
            temperature=temperature,
            num_return_sequences=1,
        )

        result = self.__tokenizer.decode(
            generated_sequence[0], skip_special_tokens=True
        )

        events = self.tokens_to_section_events(tokens=result.split(" "), section=section)
        # Fallback option when inference fails (sometimes the model generates a sequence that doesn't contain the section)
        if len(events) == 0:
            events = self.tokens_to_section_events(tokens=result.split(" "), section=known_section, step_offset=self.section_to_step_offset(section=section))

        assert len(events) > 0, "Empty result"

        return events
|
model/config.json
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "gpt2",
|
| 3 |
+
"activation_function": "gelu_new",
|
| 4 |
+
"architectures": [
|
| 5 |
+
"GPT2LMHeadModel"
|
| 6 |
+
],
|
| 7 |
+
"attn_pdrop": 0.1,
|
| 8 |
+
"bos_token_id": 50256,
|
| 9 |
+
"embd_pdrop": 0.1,
|
| 10 |
+
"eos_token_id": 50256,
|
| 11 |
+
"initializer_range": 0.02,
|
| 12 |
+
"layer_norm_epsilon": 1e-05,
|
| 13 |
+
"model_type": "gpt2",
|
| 14 |
+
"n_ctx": 1024,
|
| 15 |
+
"n_embd": 768,
|
| 16 |
+
"n_head": 12,
|
| 17 |
+
"n_inner": null,
|
| 18 |
+
"n_layer": 12,
|
| 19 |
+
"n_positions": 1024,
|
| 20 |
+
"reorder_and_upcast_attn": false,
|
| 21 |
+
"resid_pdrop": 0.1,
|
| 22 |
+
"scale_attn_by_inverse_layer_idx": false,
|
| 23 |
+
"scale_attn_weights": true,
|
| 24 |
+
"summary_activation": null,
|
| 25 |
+
"summary_first_dropout": 0.1,
|
| 26 |
+
"summary_proj_to_labels": true,
|
| 27 |
+
"summary_type": "cls_index",
|
| 28 |
+
"summary_use_proj": true,
|
| 29 |
+
"task_specific_params": {
|
| 30 |
+
"text-generation": {
|
| 31 |
+
"do_sample": true,
|
| 32 |
+
"max_length": 50
|
| 33 |
+
}
|
| 34 |
+
},
|
| 35 |
+
"torch_dtype": "float32",
|
| 36 |
+
"transformers_version": "4.28.1",
|
| 37 |
+
"use_cache": true,
|
| 38 |
+
"vocab_size": 50321
|
| 39 |
+
}
|
model/generation_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 50256,
|
| 4 |
+
"eos_token_id": 50256,
|
| 5 |
+
"transformers_version": "4.28.1"
|
| 6 |
+
}
|
model/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f1e51cee355f39d25c000d17d50c4313dd1787f086ce23191b8a495f9c33a82b
|
| 3 |
+
size 510594621
|
model/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:94b82dda5c87ea468fb088c62a5c04ce9158aed8f35b34ed6e7ab193f4cb4c8f
|
| 3 |
+
size 3707
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
MIDIUtil
|
| 3 |
+
transformers
|
| 4 |
+
torch
|
tokenizer/added_tokens.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"\n\n###\n\n": 50320,
|
| 3 |
+
"genre:DHouse": 50314,
|
| 4 |
+
"genre:Trap": 50313,
|
| 5 |
+
"pitch:36": 50289,
|
| 6 |
+
"pitch:37": 50290,
|
| 7 |
+
"pitch:38": 50291,
|
| 8 |
+
"pitch:39": 50292,
|
| 9 |
+
"pitch:40": 50293,
|
| 10 |
+
"pitch:41": 50294,
|
| 11 |
+
"pitch:42": 50295,
|
| 12 |
+
"pitch:43": 50296,
|
| 13 |
+
"pitch:44": 50297,
|
| 14 |
+
"pitch:45": 50298,
|
| 15 |
+
"pitch:46": 50299,
|
| 16 |
+
"pitch:47": 50300,
|
| 17 |
+
"pitch:48": 50301,
|
| 18 |
+
"pitch:49": 50302,
|
| 19 |
+
"pitch:50": 50303,
|
| 20 |
+
"pitch:51": 50304,
|
| 21 |
+
"pitch:52": 50305,
|
| 22 |
+
"pitch:53": 50306,
|
| 23 |
+
"pitch:54": 50307,
|
| 24 |
+
"pitch:55": 50308,
|
| 25 |
+
"pitch:56": 50309,
|
| 26 |
+
"pitch:57": 50310,
|
| 27 |
+
"pitch:58": 50311,
|
| 28 |
+
"pitch:59": 50312,
|
| 29 |
+
"section:a": 50315,
|
| 30 |
+
"section:b": 50316,
|
| 31 |
+
"section:c": 50317,
|
| 32 |
+
"section:d": 50318,
|
| 33 |
+
"step:0": 50257,
|
| 34 |
+
"step:1": 50258,
|
| 35 |
+
"step:10": 50267,
|
| 36 |
+
"step:11": 50268,
|
| 37 |
+
"step:12": 50269,
|
| 38 |
+
"step:13": 50270,
|
| 39 |
+
"step:14": 50271,
|
| 40 |
+
"step:15": 50272,
|
| 41 |
+
"step:16": 50273,
|
| 42 |
+
"step:17": 50274,
|
| 43 |
+
"step:18": 50275,
|
| 44 |
+
"step:19": 50276,
|
| 45 |
+
"step:2": 50259,
|
| 46 |
+
"step:20": 50277,
|
| 47 |
+
"step:21": 50278,
|
| 48 |
+
"step:22": 50279,
|
| 49 |
+
"step:23": 50280,
|
| 50 |
+
"step:24": 50281,
|
| 51 |
+
"step:25": 50282,
|
| 52 |
+
"step:26": 50283,
|
| 53 |
+
"step:27": 50284,
|
| 54 |
+
"step:28": 50285,
|
| 55 |
+
"step:29": 50286,
|
| 56 |
+
"step:3": 50260,
|
| 57 |
+
"step:30": 50287,
|
| 58 |
+
"step:31": 50288,
|
| 59 |
+
"step:4": 50261,
|
| 60 |
+
"step:5": 50262,
|
| 61 |
+
"step:6": 50263,
|
| 62 |
+
"step:7": 50264,
|
| 63 |
+
"step:8": 50265,
|
| 64 |
+
"step:9": 50266,
|
| 65 |
+
"which_genre_section": 50319
|
| 66 |
+
}
|
tokenizer/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer/special_tokens_map.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<|endoftext|>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": true,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "<|endoftext|>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": true,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": "<|endoftext|>",
|
| 17 |
+
"unk_token": {
|
| 18 |
+
"content": "<|endoftext|>",
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"normalized": true,
|
| 21 |
+
"rstrip": false,
|
| 22 |
+
"single_word": false
|
| 23 |
+
}
|
| 24 |
+
}
|
tokenizer/tokenizer_config.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_prefix_space": false,
|
| 4 |
+
"bos_token": {
|
| 5 |
+
"__type": "AddedToken",
|
| 6 |
+
"content": "<|endoftext|>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": true,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false
|
| 11 |
+
},
|
| 12 |
+
"clean_up_tokenization_spaces": true,
|
| 13 |
+
"eos_token": {
|
| 14 |
+
"__type": "AddedToken",
|
| 15 |
+
"content": "<|endoftext|>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": true,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false
|
| 20 |
+
},
|
| 21 |
+
"errors": "replace",
|
| 22 |
+
"model_max_length": 1024,
|
| 23 |
+
"pad_token": null,
|
| 24 |
+
"tokenizer_class": "GPT2Tokenizer",
|
| 25 |
+
"unk_token": {
|
| 26 |
+
"__type": "AddedToken",
|
| 27 |
+
"content": "<|endoftext|>",
|
| 28 |
+
"lstrip": false,
|
| 29 |
+
"normalized": true,
|
| 30 |
+
"rstrip": false,
|
| 31 |
+
"single_word": false
|
| 32 |
+
}
|
| 33 |
+
}
|
tokenizer/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|