from transformers import GPT2Tokenizer, GPT2LMHeadModel


class CustomTokenEncoderDecoder:
    """Encode note events as text tokens for a fine-tuned GPT-2 model and
    decode generated token sequences back into (step, pitch) events.

    Token vocabulary (plain whitespace-separated words):
        "step:<n>"     -- a time step within a section
        "pitch:<n>"    -- a note pitch sounding at the preceding step token
        "section:<s>"  -- a section label ("a".."d")
        "genre:<g>"    -- a genre label
    """

    # Sentinel appended to a prompt to ask the model to emit genre/section labels.
    CUSTOM_CLASSIFICATION_TOKEN = "which_genre_section"

    def __init__(self,
                 events: list,
                 sections: list,
                 steps_per_section: int,
                 model: GPT2LMHeadModel,
                 tokenizer: GPT2Tokenizer):
        """
        Args:
            events: (step, pitch) pairs for the known (seed) section; steps are
                relative to the start of that section.
            sections: all section labels in play (presumably a subset of
                "a".."d" -- see section_to_step_offset).
            steps_per_section: number of time steps per section; also the
                step-offset multiplier between consecutive sections.
            model: fine-tuned GPT-2 language model used for generation.
            tokenizer: tokenizer matching `model`.
        """
        self.__model = model
        self.__tokenizer = tokenizer
        self.__events = events
        self.__steps_per_section = steps_per_section
        self.__sections = sections
        self.__events_tokens = self.events_to_tokens(events)

    # --- token predicates -------------------------------------------------

    def is_step_token(self, token: str) -> bool:
        return token.startswith("step:")

    def is_pitch_token(self, token: str) -> bool:
        return token.startswith("pitch:")

    def is_genre_token(self, token: str) -> bool:
        return token.startswith("genre:")

    def is_section_token(self, token: str) -> bool:
        return token.startswith("section:")

    # --- token <-> value conversions --------------------------------------

    def token_to_pitch(self, token: str) -> int:
        return int(token.split(":")[1])

    def token_to_step(self, token: str) -> int:
        return int(token.split(":")[1])

    def token_to_section(self, token: str) -> str:
        return token.split(":")[1]

    def token_to_genre(self, token: str) -> str:
        return token.split(":")[1]

    def pitch_to_token(self, pitch: int) -> str:
        return "pitch:{0}".format(pitch)

    def step_to_token(self, step: int) -> str:
        # Annotation fixed: returns a single token string, not a list.
        return "step:{0}".format(step)

    def section_to_token(self, section: str) -> str:
        # Annotation fixed: returns a single token string, not a list.
        return "section:{0}".format(section)

    def events_to_tokens(self, events: list) -> list:
        """Serialize (step, pitch) events into a flat token list.

        Each populated step contributes one "step:" token followed by its
        "pitch:" tokens. Steps outside [0, steps_per_section) are silently
        dropped, matching the original behavior.
        """
        result = []
        for step_id in range(self.__steps_per_section):
            step_data = [event for event in events if event[0] == step_id]
            if step_data:
                result.append(self.step_to_token(step_id))
                result += [self.pitch_to_token(event[1]) for event in step_data]
        return result

    def tokens_to_classification_prompt(self, tokens: list) -> str:
        """Build the prompt that asks the model for genre/section labels."""
        return " ".join(tokens + [self.CUSTOM_CLASSIFICATION_TOKEN])

    def tokens_to_section_prompt(self, tokens: list, section: str,
                                 prompted_section: str) -> str:
        """Build a prompt seeded with `section`'s tokens and ending with the
        `prompted_section` marker, so the model continues with that section."""
        return " ".join([self.section_to_token(section)]
                        + tokens
                        + [self.section_to_token(prompted_section)])

    def tokens_to_genre_section(self, tokens: list) -> dict:
        # CONSISTENCY FIX: this was a byte-for-byte duplicate of
        # tokens_to_genre_and_section_information; delegate to avoid drift.
        return self.tokens_to_genre_and_section_information(tokens)

    def section_to_step_offset(self, section: str) -> int:
        """Map a section label ("a".."d") to its absolute step offset.

        Raises:
            Exception: for any label outside "a".."d".
        """
        multipliers = {"a": 0, "b": 1, "c": 2, "d": 3}
        if section not in multipliers:
            raise Exception("Invalid section: {0}".format(section))
        return multipliers[section] * self.__steps_per_section

    def tokens_to_section_events(self, tokens: list, section: str,
                                 step_offset: int = None) -> list:
        """Decode the events belonging to `section` from a token sequence.

        Scans for the first "section:<section>" marker and decodes every
        step/pitch token from there to the END of `tokens` (tokens of any
        later sections are included -- unchanged from the original).
        `step_offset`, when given, overrides the section's default offset.

        Raises:
            Exception: if the section marker never appears in `tokens`.
        """
        for token_id, token in enumerate(tokens):
            if self.is_section_token(token) and self.token_to_section(token) == section:
                offset = self.section_to_step_offset(section) if step_offset is None else step_offset
                return self.tokens_to_events(tokens=tokens[token_id:], step_offset=offset)
        raise Exception("Section {0} not found in tokens".format(section))

    def tokens_to_events(self, tokens: list, step_offset: int) -> list:
        """Decode "step:"/"pitch:" tokens into absolute (step, pitch) pairs."""
        result = []
        for token_id, token in enumerate(tokens):
            if self.is_step_token(token):
                step = self.token_to_step(token) + step_offset
                # Collect the run of pitch tokens immediately following the step.
                next_token_id = token_id + 1
                while next_token_id < len(tokens) and self.is_pitch_token(tokens[next_token_id]):
                    result.append((step, self.token_to_pitch(tokens[next_token_id])))
                    next_token_id += 1
        return result

    def convert_events_to_section_events(self, events: list, section: str) -> list:
        """Shift section-relative events by `section`'s absolute step offset."""
        # BUGFIX: previously called the non-existent self.step_offset_for_section(),
        # which raised AttributeError; the actual method is section_to_step_offset().
        offset = self.section_to_step_offset(section)
        return list(map(lambda x: (x[0] + offset, x[1]), events))

    def generate_events(self, temperature: float) -> dict:
        """Generate a full piece of events.

        Classifies the seed events' genre/section, then prompts the model once
        per remaining section and collects all decoded events.

        Returns:
            {"events": [(step, pitch), ...], "genre": <str>}

        Raises:
            Exception: if the classified section is not in the configured sections.
        """
        genre_section_data = self.make_classification_inference(temperature=temperature)
        genre = genre_section_data["genre"]
        section = genre_section_data["section"]
        print("Classification results")
        print("======================")
        print("Found genre: {0}".format(genre))
        print("Found section: {0}".format(section))
        print("======================")
        # Validate before using the section for offset computation.
        if section not in self.__sections:
            raise Exception("Section {0} not found in sections".format(section))
        # Seed events shifted to the classified section's absolute offset
        # (offset hoisted out of the per-event mapping).
        seed_offset = self.section_to_step_offset(section=section)
        all_events = [(x[0] + seed_offset, x[1]) for x in self.__events]
        other_sections = [s for s in self.__sections if s != section]
        for other_section in other_sections:
            prompt = self.tokens_to_section_prompt(tokens=self.__events_tokens,
                                                   section=section,
                                                   prompted_section=other_section)
            all_events += self.make_section_events_inference(prompt=prompt,
                                                             temperature=temperature,
                                                             section=other_section,
                                                             known_section=section)
        return {
            "events": all_events,
            "genre": genre
        }

    def tokens_to_genre_and_section_information(self, tokens: list) -> dict:
        """Extract the LAST genre and section labels present in `tokens`.

        Missing labels come back as empty strings.
        """
        genre = ""
        section = ""
        for token in tokens:
            if self.is_genre_token(token):
                genre = self.token_to_genre(token)
            elif self.is_section_token(token):
                section = self.token_to_section(token)
        return {
            "genre": genre,
            "section": section
        }

    def make_classification_inference(self, temperature: float) -> dict:
        """Ask the model which genre/section the seed events belong to."""
        genre_and_section_prompt = self.tokens_to_classification_prompt(self.__events_tokens)
        prompt = self.__tokenizer.encode(genre_and_section_prompt,
                                         add_special_tokens=True,
                                         return_tensors="pt")
        generated_section_genre_sequence = self.__model.generate(
            prompt,
            max_length=1024,
            do_sample=True,
            # BUGFIX: the `temperature` parameter was silently ignored here
            # (the call hard-coded 0.1); honor the caller's value.
            temperature=temperature,
            num_return_sequences=1,
        )
        section_genre_result = self.__tokenizer.decode(generated_section_genre_sequence[0],
                                                       skip_special_tokens=True)
        assert len(section_genre_result) > 0, "Empty result"
        return self.tokens_to_genre_and_section_information(section_genre_result.split(" "))

    def make_section_events_inference(self, prompt: str, section: str,
                                      temperature: float, known_section: str) -> list:
        """Generate and decode the events for `section` from `prompt`.

        Falls back to decoding the `known_section` span (shifted to `section`'s
        offset) when the model output contains the section marker but no events.

        NOTE(review): if the generated text lacks the section marker entirely,
        tokens_to_section_events raises before the fallback runs -- the marker
        is normally present because the prompt itself ends with it, but the
        fallback does not cover a truly missing marker; confirm intent.
        """
        tokenised_prompt = self.__tokenizer.encode(prompt,
                                                   add_special_tokens=True,
                                                   return_tensors="pt")
        assert len(tokenised_prompt[0]) <= 1024, "Prompt length exceeds maximum sequence length"
        generated_sequence = self.__model.generate(
            tokenised_prompt,
            max_length=1024,
            do_sample=True,
            temperature=temperature,
            num_return_sequences=1,
        )
        result = self.__tokenizer.decode(
            generated_sequence[0],
            skip_special_tokens=True
        )
        events = self.tokens_to_section_events(tokens=result.split(" "), section=section)
        # Fallback option when inference fails (sometimes the model generates a sequence that doesn't contain the section)
        if len(events) == 0:
            events = self.tokens_to_section_events(tokens=result.split(" "),
                                                   section=known_section,
                                                   step_offset=self.section_to_step_offset(section=section))
        assert len(events) > 0, "Empty result"
        return events