# auto-beater / beatgenerator.py
# Author: Achillefs Sourlas
# Changed the data format (commit ec67254)
from midiutil import MIDIFile
import base64
from io import BytesIO
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from customtokenencoderdecoder import CustomTokenEncoderDecoder
class BeatGenerator:
    """Generates drum beats from a user-supplied step pattern using a GPT-2 model.

    The user prompt is a list of step lists, one per drum pitch in
    ``PITCHES`` (same order).  It is flattened into ``(step, pitch)`` events,
    fed through ``CustomTokenEncoderDecoder``, and the model's continuation is
    returned either as a base64-encoded MIDI file (``generate_beat_old``) or
    as per-pitch step lists (``generate_beat``).
    """

    # Duration of one sequencer step in beats (a 16th note in 4/4).
    STEP_SIZE = 0.25
    # Number of steps in one generated section.
    STEPS_PER_SEQUENCE = 32
    # General MIDI percussion pitches understood by the model
    # (36=kick, 38=snare, 42=closed hat, ... -- TODO confirm full mapping).
    PITCHES = [36, 38, 39, 42, 45, 46, 47, 49, 51]

    def __init__(self, model: "GPT2LMHeadModel", tokenizer: "GPT2Tokenizer"):
        """Store the generation model, its tokenizer, and the section labels.

        Annotations are string-quoted so the class can be defined without
        importing the (heavy) transformers types at evaluation time.
        """
        self.__model = model
        self.__tokenizer = tokenizer
        # Section labels passed to the encoder/decoder; order matters.
        self.__sections = ["a", "b", "c", "d"]

    def generate_beat_old(self, user_prompt: list[list[int]], temperature: float, tempo: float) -> tuple[str, str]:
        """Generate a beat and return ``(genre, base64-encoded MIDI file)``.

        :param user_prompt: one list of active steps per pitch in ``PITCHES``.
        :param temperature: sampling temperature for the model.
        :param tempo: tempo (BPM) written into the MIDI file.
        :raises ValueError: if ``user_prompt`` has a row count different from
            ``len(PITCHES)``.
        """
        genre, events = self.__generate(user_prompt, temperature)
        midi_buffer = self.__make_midi_buffer(
            data_container=events,
            tempo=tempo,
            verbose=False,
        )
        midi_base64 = base64.b64encode(midi_buffer.read()).decode("utf-8")
        return genre, midi_base64

    def generate_beat(self, user_prompt: list[list[int]], temperature: float, tempo: float) -> tuple[str, list[list[int]]]:
        """Generate a beat and return ``(genre, per-pitch step lists)``.

        ``tempo`` is accepted for interface compatibility with
        ``generate_beat_old`` but is unused: this data format returns raw
        steps rather than timed MIDI.

        :raises ValueError: if ``user_prompt`` has a row count different from
            ``len(PITCHES)``.
        """
        genre, events = self.__generate(user_prompt, temperature)
        return genre, self.__make_events(data_container=events, pitches=self.PITCHES)

    def __generate(self, user_prompt: list[list[int]], temperature: float) -> tuple[str, list]:
        """Encode the prompt, run the model, and return ``(genre, events)``.

        Shared by both public generate methods (they previously duplicated
        this logic).  ``events`` is whatever ``generate_events`` produces:
        an iterable of ``(step, pitch)`` pairs.
        """
        # Validate with an exception rather than assert (asserts vanish
        # under ``python -O``).
        if len(user_prompt) != len(self.PITCHES):
            raise ValueError("User prompt length must be equal to the number of pitches")
        # Flatten the per-pitch step lists into (step, pitch) event tuples.
        user_events = [
            (step, pitch)
            for pitch, steps in zip(self.PITCHES, user_prompt)
            for step in steps
        ]
        coder = CustomTokenEncoderDecoder(
            events=user_events,
            sections=self.__sections,
            steps_per_section=self.STEPS_PER_SEQUENCE,
            model=self.__model,
            tokenizer=self.__tokenizer,
        )
        result = coder.generate_events(temperature=temperature)
        return result["genre"], result["events"]

    def __make_midi_buffer(self, data_container: list, tempo: float, verbose: bool = False) -> BytesIO:
        """Render ``(step, pitch)`` events into an in-memory MIDI file.

        Events outside the valid step/pitch range ``[0, 128)`` are silently
        skipped.  Returns a :class:`BytesIO` positioned at the start.
        (The previous version also wrote a debug copy to ``out.mid`` in the
        current directory on every call; that side effect was removed.)
        """
        out_midi_file = MIDIFile(1)  # single track
        out_midi_file.addTempo(0, 0, tempo)
        for step, pitch in data_container:
            if verbose:
                # Fixed: the old message referenced undefined names
                # (step_ranges / section_id) and raised NameError.
                print("Processing: {0}".format((step, pitch)))
            if 0 <= step < 128 and 0 <= pitch < 128:
                out_midi_file.addNote(
                    track=0,
                    channel=9,  # channel 10 (0-indexed): General MIDI percussion
                    pitch=pitch,
                    time=float(step) * self.STEP_SIZE,
                    duration=self.STEP_SIZE,
                    volume=100,  # fixed velocity for every note
                )
        buffer = BytesIO()
        out_midi_file.writeFile(buffer)
        buffer.seek(0)
        return buffer

    def __make_events(self, data_container: list, pitches: list[int], verbose: bool = False) -> list[list[int]]:
        """Group ``(step, pitch)`` events into one step list per pitch.

        Returns a list parallel to ``pitches``; ``result[i]`` holds the steps
        at which ``pitches[i]`` fires.  Events with an out-of-range step or a
        pitch not present in ``pitches`` are skipped (the previous
        ``pitches.index`` call raised ValueError on unknown pitches).
        """
        # Dict for O(1) membership + position lookup instead of list.index.
        index_of = {pitch: i for i, pitch in enumerate(pitches)}
        result: list[list[int]] = [[] for _ in pitches]
        for data in data_container:
            step, pitch = data
            if verbose:
                # Fixed: the old message referenced undefined names
                # (step_ranges / section_id) and raised NameError.
                print("Processing: {0}".format(data))
            if 0 <= step < 128 and pitch in index_of:
                result[index_of[pitch]].append(step)
        return result