# simple-beat-generator / customtokenencoderdecoder.py
# Author: Achillefs Sourlas
# Replaced the model with another fine-tuned GPT-2 with custom tokens.
# Commit: fdeb8e7
from transformers import GPT2Tokenizer, GPT2LMHeadModel
class CustomTokenEncoderDecoder:
    """Converts note events to/from a custom GPT-2 token vocabulary.

    Events are ``(step, pitch)`` pairs.  Tokens are strings of the form
    ``"step:<n>"``, ``"pitch:<n>"``, ``"section:<name>"`` and
    ``"genre:<name>"``.  The class builds text prompts for a fine-tuned
    GPT-2 model and decodes the generated token sequences back into
    events, one song section at a time (sections "a".."d", each spanning
    ``steps_per_section`` steps).
    """

    # Marker token appended to a prompt to ask the fine-tuned model for a
    # genre/section classification of the preceding events.
    CUSTOM_CLASSIFICATION_TOKEN = "which_genre_section"

    def __init__(self, events: list[tuple[int, int]], sections: list[str], steps_per_section: int, model: "GPT2LMHeadModel", tokenizer: "GPT2Tokenizer"):
        """Store the model/tokenizer pair and pre-tokenize the seed events.

        Args:
            events: seed (step, pitch) pairs for one known section.
            sections: every section name the song should contain.
            steps_per_section: number of time steps each section spans.
            model: fine-tuned GPT-2 language model used for generation.
            tokenizer: tokenizer matching *model*.
        """
        self.__model = model
        self.__tokenizer = tokenizer
        self.__events = events
        self.__steps_per_section = steps_per_section
        self.__sections = sections
        self.__events_tokens = self.events_to_tokens(events)

    # ------------------------------------------------------------------ #
    # Token predicates                                                   #
    # ------------------------------------------------------------------ #

    def is_step_token(self, token: str) -> bool:
        """True when *token* is a "step:<n>" token."""
        return token.startswith("step:")

    def is_pitch_token(self, token: str) -> bool:
        """True when *token* is a "pitch:<n>" token."""
        return token.startswith("pitch:")

    def is_genre_token(self, token: str) -> bool:
        """True when *token* is a "genre:<name>" token."""
        return token.startswith("genre:")

    def is_section_token(self, token: str) -> bool:
        """True when *token* is a "section:<name>" token."""
        return token.startswith("section:")

    # ------------------------------------------------------------------ #
    # Token <-> value conversions                                        #
    # ------------------------------------------------------------------ #

    def token_to_pitch(self, token: str) -> int:
        """Extract the integer pitch from a "pitch:<n>" token."""
        return int(token.split(":")[1])

    def token_to_step(self, token: str) -> int:
        """Extract the integer step from a "step:<n>" token."""
        return int(token.split(":")[1])

    def token_to_section(self, token: str) -> str:
        """Extract the section name from a "section:<name>" token."""
        return token.split(":")[1]

    def token_to_genre(self, token: str) -> str:
        """Extract the genre name from a "genre:<name>" token."""
        return token.split(":")[1]

    def pitch_to_token(self, pitch: int) -> str:
        """Render *pitch* as a "pitch:<n>" token."""
        return "pitch:{0}".format(pitch)

    # NOTE: original return annotations here were ``[str]`` but these
    # methods return a single string.
    def step_to_token(self, step: int) -> str:
        """Render *step* as a "step:<n>" token."""
        return "step:{0}".format(step)

    def section_to_token(self, section: str) -> str:
        """Render *section* as a "section:<name>" token."""
        return "section:{0}".format(section)

    def events_to_tokens(self, events: list[tuple[int, int]]) -> list[str]:
        """Serialize events into tokens ordered by ascending step.

        Emits one "step:<n>" token followed by a "pitch:<p>" token for
        every event at that step; steps with no events are skipped.
        Events whose step falls outside ``range(steps_per_section)`` are
        ignored, as in the original implementation.
        """
        # Group pitches by step in one pass instead of re-filtering the
        # whole event list for every step (was O(steps * events)).
        pitches_by_step: dict[int, list[int]] = {}
        for step, pitch in events:
            pitches_by_step.setdefault(step, []).append(pitch)
        result: list[str] = []
        for step_id in range(self.__steps_per_section):
            if step_id in pitches_by_step:
                result.append(self.step_to_token(step_id))
                result += [self.pitch_to_token(p) for p in pitches_by_step[step_id]]
        return result

    def tokens_to_classification_prompt(self, tokens: list[str]) -> str:
        """Build a prompt asking the model to classify genre and section."""
        return " ".join(tokens + [self.CUSTOM_CLASSIFICATION_TOKEN])

    def tokens_to_section_prompt(self, tokens: list[str], section: str, prompted_section: str) -> str:
        """Build a generation prompt: the known section's tokens framed by
        its section marker, ending with the marker of the section the
        model should continue with."""
        return " ".join([self.section_to_token(section)] + tokens + [self.section_to_token(prompted_section)])

    def tokens_to_genre_section(self, tokens: list[str]) -> dict:
        """Extract genre/section from *tokens*.

        Delegates to :meth:`tokens_to_genre_and_section_information`
        (the two methods were previously byte-identical duplicates).
        """
        return self.tokens_to_genre_and_section_information(tokens)

    def section_to_step_offset(self, section: str) -> int:
        """Map a section name to its absolute step offset.

        Sections are laid out consecutively: a, b, c, d.

        Raises:
            Exception: for any section name outside "a".."d".
        """
        offsets = {"a": 0, "b": 1, "c": 2, "d": 3}
        if section not in offsets:
            raise Exception("Invalid section: {0}".format(section))
        return offsets[section] * self.__steps_per_section

    def tokens_to_section_events(self, tokens: list[str], section: str, step_offset: int = None) -> list[tuple[int, int]]:
        """Decode the events that belong to *section* from a token sequence.

        Scans for the first "section:<section>" token and decodes every
        step/pitch token after it.  *step_offset*, when given, overrides
        the section's default absolute offset (and the default is then
        not computed at all — the original computed it eagerly, which
        could raise for a non-standard section name even though the
        value was discarded).

        Raises:
            Exception: when the section token is not present in *tokens*.
        """
        for token_id, token in enumerate(tokens):
            if self.is_section_token(token) and self.token_to_section(token) == section:
                offset = self.section_to_step_offset(section) if step_offset is None else step_offset
                return self.tokens_to_events(tokens=tokens[token_id:], step_offset=offset)
        raise Exception("Section {0} not found in tokens".format(section))

    def tokens_to_events(self, tokens: list[str], step_offset: int) -> list[tuple[int, int]]:
        """Decode step/pitch tokens into ``(step + step_offset, pitch)`` pairs.

        A step token claims every pitch token immediately following it;
        any other token ends that step's pitch run.  Tokens that are
        neither step nor pitch are skipped.
        """
        result: list[tuple[int, int]] = []
        for token_id, token in enumerate(tokens):
            if not self.is_step_token(token):
                continue
            step = self.token_to_step(token) + step_offset
            # Collect the contiguous run of pitch tokens after this step.
            for follower in tokens[token_id + 1:]:
                if not self.is_pitch_token(follower):
                    break
                result.append((step, self.token_to_pitch(follower)))
        return result

    def convert_events_to_section_events(self, events: list[tuple[int, int]], section: str) -> list[tuple[int, int]]:
        """Shift *events* into *section*'s absolute step range.

        BUG FIX: previously called the non-existent
        ``step_offset_for_section`` and raised AttributeError on every
        call; the intended method is :meth:`section_to_step_offset`.
        """
        offset = self.section_to_step_offset(section)
        return [(step + offset, pitch) for step, pitch in events]

    def generate_events(self, temperature: float) -> dict:
        """Classify the seed events, then generate all remaining sections.

        Returns:
            dict with "events" (seed events shifted into the classified
            section, plus generated events for every other section) and
            "genre" (the classified genre name).

        Raises:
            Exception: when the classified section is not one of the
                configured sections.
        """
        genre_section_data = self.make_classification_inference(temperature=temperature)
        genre = genre_section_data["genre"]
        section = genre_section_data["section"]
        print("Classification results")
        print("======================")
        print("Found genre: {0}".format(genre))
        print("Found section: {0}".format(section))
        print("======================")
        # Validate before doing any work (original validated after
        # already shifting the seed events).
        if section not in self.__sections:
            raise Exception("Section {0} not found in sections".format(section))
        # Seed events belong to the classified section; reuse the (fixed)
        # shifting helper instead of duplicating it inline.
        all_events: list[tuple[int, int]] = self.convert_events_to_section_events(events=self.__events, section=section)
        for other_section in [s for s in self.__sections if s != section]:
            prompt = self.tokens_to_section_prompt(tokens=self.__events_tokens, section=section, prompted_section=other_section)
            all_events += self.make_section_events_inference(prompt=prompt, temperature=temperature, section=other_section, known_section=section)
        return {
            "events": all_events,
            "genre": genre
        }

    def tokens_to_genre_and_section_information(self, tokens: list[str]) -> dict:
        """Scan *tokens* for genre/section markers.

        Later markers win; missing markers yield "".
        """
        genre: str = ""
        section: str = ""
        for token in tokens:
            if self.is_genre_token(token):
                genre = self.token_to_genre(token)
            elif self.is_section_token(token):
                section = self.token_to_section(token)
        return {"genre": genre, "section": section}

    def make_classification_inference(self, temperature: float) -> dict:
        """Ask the model which genre/section the seed events belong to.

        Returns a ``{"genre": ..., "section": ...}`` dict parsed from the
        generated sequence.
        """
        genre_and_section_prompt = self.tokens_to_classification_prompt(self.__events_tokens)
        prompt = self.__tokenizer.encode(genre_and_section_prompt, add_special_tokens=True, return_tensors="pt")
        generated_section_genre_sequence = self.__model.generate(
            prompt,
            max_length=1024,
            do_sample=True,
            # NOTE(review): the *temperature* parameter is ignored here; a
            # fixed low 0.1 is used, presumably to keep classification
            # near-deterministic — confirm this is intentional.
            temperature=0.1,
            num_return_sequences=1,
        )
        section_genre_result = self.__tokenizer.decode(generated_section_genre_sequence[0], skip_special_tokens=True)
        assert len(section_genre_result) > 0, "Empty result"
        genre_section_data = self.tokens_to_genre_and_section_information(section_genre_result.split(" "))
        return genre_section_data

    def make_section_events_inference(self, prompt: str, section: str, temperature: float, known_section: str) -> list[tuple[int, int]]:
        """Generate events for *section* by continuing *prompt* with the model.

        Args:
            prompt: text prompt ending with *section*'s marker token.
            section: the section being generated.
            temperature: sampling temperature passed to ``generate``.
            known_section: the section the seed events belong to; used as
                a fallback source when the model output lacks *section*.
        """
        tokenised_prompt = self.__tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt")
        assert len(tokenised_prompt[0]) <= 1024, "Prompt length exceeds maximum sequence length"
        generated_sequence = self.__model.generate(
            tokenised_prompt,
            max_length=1024,
            do_sample=True,
            temperature=temperature,
            num_return_sequences=1,
        )
        result = self.__tokenizer.decode(
            generated_sequence[0], skip_special_tokens=True
        )
        result_tokens = result.split(" ")
        # BUG FIX: tokens_to_section_events *raises* when the section token
        # is absent — it never returns an empty list — so the original
        # fallback below was unreachable for exactly the failure mode it
        # was written for.  Catch the exception to engage the fallback.
        try:
            events = self.tokens_to_section_events(tokens=result_tokens, section=section)
        except Exception:
            events = []
        # Fallback when inference fails (sometimes the model generates a
        # sequence that doesn't contain the requested section): decode from
        # the known section's marker (always present, since it was part of
        # the prompt) but shift into the requested section's step range.
        if len(events) == 0:
            events = self.tokens_to_section_events(tokens=result_tokens, section=known_section, step_offset=self.section_to_step_offset(section=section))
        assert len(events) > 0, "Empty result"
        return events