File size: 8,570 Bytes
fdeb8e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
from transformers import GPT2Tokenizer, GPT2LMHeadModel

class CustomTokenEncoderDecoder:
    """Translate between note events and the text tokens used to prompt a model.

    Events are ``(step, pitch)`` pairs. They are serialised as ``step:<n>``
    tokens, each followed by the ``pitch:<n>`` tokens that belong to that
    step. Text decoded from the model is scanned for ``genre:<name>`` and
    ``section:<name>`` tokens in addition to step/pitch tokens.
    """

    CUSTOM_CLASSIFICATION_TOKEN = "which_genre_section"

    # Known section names; a section's index multiplied by the number of
    # steps per section is that section's step offset.
    _SECTION_ORDER = ("a", "b", "c", "d")

    def __init__(
        self,
        events: [[int, int]],
        sections: [str],
        steps_per_section: int,
        model: "GPT2LMHeadModel",
        tokenizer: "GPT2Tokenizer",
    ):
        """Store the inputs and pre-compute the token form of ``events``.

        Args:
            events: ``(step, pitch)`` pairs with steps relative to the start
                of a single section.
            sections: Names of the sections that ``generate_events`` should
                produce.
            steps_per_section: Number of steps in one section.
            model: Language model used for generation.
            tokenizer: Tokenizer matching ``model``.
        """
        self.__model = model
        self.__tokenizer = tokenizer
        self.__events = events
        self.__steps_per_section = steps_per_section
        self.__sections = sections
        self.__events_tokens = self.events_to_tokens(events)

    def is_step_token(self, token: str) -> bool:
        """Return True if ``token`` starts with ``step:``."""
        return token.startswith("step:")

    def is_pitch_token(self, token: str) -> bool:
        """Return True if ``token`` starts with ``pitch:``."""
        return token.startswith("pitch:")

    def is_genre_token(self, token: str) -> bool:
        """Return True if ``token`` starts with ``genre:``."""
        return token.startswith("genre:")

    def is_section_token(self, token: str) -> bool:
        """Return True if ``token`` starts with ``section:``."""
        return token.startswith("section:")

    def token_to_pitch(self, token: str) -> int:
        """Return the integer after the first ``:`` of a pitch token.

        Raises:
            ValueError: If that part of the token is not an integer.
        """
        return int(token.split(":")[1])

    def token_to_step(self, token: str) -> int:
        """Return the integer after the first ``:`` of a step token.

        Raises:
            ValueError: If that part of the token is not an integer.
        """
        return int(token.split(":")[1])

    def token_to_section(self, token: str) -> str:
        """Return the text between the first and second ``:`` of a section token."""
        return token.split(":")[1]

    def token_to_genre(self, token: str) -> str:
        """Return the text between the first and second ``:`` of a genre token."""
        return token.split(":")[1]

    def pitch_to_token(self, pitch: int) -> str:
        """Return the ``pitch:<pitch>`` token for ``pitch``."""
        return "pitch:{0}".format(pitch)

    def step_to_token(self, step: int) -> str:
        """Return the ``step:<step>`` token for ``step``."""
        return "step:{0}".format(step)

    def section_to_token(self, section: str) -> str:
        """Return the ``section:<section>`` token for ``section``."""
        return "section:{0}".format(section)

    def events_to_tokens(self, events: [[int, int]]) -> [str]:
        """Serialise events as step tokens, each followed by its pitch tokens.

        Steps are emitted in ascending order from 0 up to (but excluding) the
        number of steps per section; events on other steps are dropped. Steps
        without events produce no token. Pitches of one step keep the order
        they have in ``events``.
        """
        # Group once instead of re-scanning the whole event list per step.
        events_by_step = {}
        for event in events:
            events_by_step.setdefault(event[0], []).append(event)

        result: [str] = []
        for step_id in range(self.__steps_per_section):
            step_events = events_by_step.get(step_id)
            if not step_events:
                continue
            result.append(self.step_to_token(step_id))
            result.extend(self.pitch_to_token(event[1]) for event in step_events)

        return result

    def tokens_to_classification_prompt(self, tokens: [str]) -> str:
        """Join ``tokens`` with spaces and append the classification token."""
        return " ".join([*tokens, self.CUSTOM_CLASSIFICATION_TOKEN])

    def tokens_to_section_prompt(self, tokens: [str], section: str, prompted_section: str) -> str:
        """Build a prompt asking for ``prompted_section`` given a known section.

        The prompt is the token of ``section``, then ``tokens``, then the
        token of ``prompted_section``, all separated by single spaces.
        """
        return " ".join(
            [
                self.section_to_token(section),
                *tokens,
                self.section_to_token(prompted_section),
            ]
        )

    def tokens_to_genre_section(self, tokens: [str]) -> dict:
        """Extract the genre and section named in ``tokens``.

        Returns:
            A dict with keys ``"genre"`` and ``"section"``. When several genre
            (or section) tokens are present the last one wins; a missing one
            is reported as an empty string.
        """
        genre: str = ""
        section: str = ""

        for token in tokens:
            if self.is_genre_token(token):
                genre = self.token_to_genre(token)
            elif self.is_section_token(token):
                section = self.token_to_section(token)

        return {"genre": genre, "section": section}

    def section_to_step_offset(self, section: str) -> int:
        """Return the step offset of ``section``.

        Sections ``a``, ``b``, ``c`` and ``d`` map to 0, 1, 2 and 3 times the
        number of steps per section.

        Raises:
            Exception: If ``section`` is not one of the four known names.
        """
        try:
            index = self._SECTION_ORDER.index(section)
        except ValueError:
            raise Exception("Invalid section: {0}".format(section)) from None
        return index * self.__steps_per_section

    def tokens_to_section_events(self, tokens: [str], section: str, step_offset: int = None) -> [[int, int]]:
        """Parse the events that follow the first token of ``section``.

        Parsing starts at the first ``section:<section>`` token and runs to
        the end of ``tokens``; it does not stop at later section tokens.

        Args:
            tokens: Tokens to scan.
            section: Name of the section to look for.
            step_offset: Offset added to every parsed step. When ``None`` the
                offset of ``section`` is used.

        Raises:
            Exception: If no token for ``section`` is present, or if
                ``section`` is not a known section name.
        """
        for (token_id, token) in enumerate(tokens):
            if not self.is_section_token(token):
                continue
            if self.token_to_section(token) != section:
                continue
            offset: int = self.section_to_step_offset(section)
            if step_offset is not None:
                offset = step_offset
            return self.tokens_to_events(tokens=tokens[token_id:], step_offset=offset)

        raise Exception("Section {0} not found in tokens".format(section))

    def tokens_to_events(self, tokens: [str], step_offset: int) -> [[int, int]]:
        """Parse ``(step, pitch)`` tuples out of ``tokens``.

        A pitch token produces an event only when it belongs to the unbroken
        run of pitch tokens directly after a step token; any other token ends
        that run. ``step_offset`` is added to every parsed step.
        """
        result: [[int, int]] = []
        current_step = None  # None while not directly inside a step's pitch run

        for token in tokens:
            if self.is_step_token(token):
                current_step = self.token_to_step(token) + step_offset
            elif current_step is not None and self.is_pitch_token(token):
                result.append((current_step, self.token_to_pitch(token)))
            else:
                current_step = None

        return result

    def convert_events_to_section_events(self, events: [[int, int]], section: str) -> [[int, int]]:
        """Return ``events`` as tuples with their steps shifted into ``section``.

        Raises:
            Exception: If ``section`` is not a known section name.
        """
        offset = self.section_to_step_offset(section)
        return [(event[0] + offset, event[1]) for event in events]

    def generate_events(self, temperature: float) -> dict:
        """Classify the stored events and generate every other section.

        The model is first asked which genre and section the stored events
        represent. The stored events are shifted to that section's offset and
        one generation is then run for each remaining configured section.

        Args:
            temperature: Passed on to both inference helpers.

        Returns:
            A dict with ``"events"`` (all events, stored and generated) and
            ``"genre"`` (the classified genre).

        Raises:
            Exception: If the classified section is not among the configured
                sections.
        """
        genre_section_data = self.make_classification_inference(temperature=temperature)
        genre = genre_section_data["genre"]
        section = genre_section_data["section"]
        print("Classification results")
        print("======================")
        print("Found genre: {0}".format(genre))
        print("Found section: {0}".format(section))
        print("======================")

        # Validate before doing any offset arithmetic with the section name.
        if section not in self.__sections:
            raise Exception("Section {0} not found in sections".format(section))

        all_events: [[int, int]] = self.convert_events_to_section_events(self.__events, section)

        for other_section in self.__sections:
            if other_section == section:
                continue
            prompt = self.tokens_to_section_prompt(
                tokens=self.__events_tokens,
                section=section,
                prompted_section=other_section,
            )
            all_events += self.make_section_events_inference(
                prompt=prompt,
                temperature=temperature,
                section=other_section,
                known_section=section,
            )

        return {
            "events": all_events,
            "genre": genre,
        }

    def tokens_to_genre_and_section_information(self, tokens: [str]) -> dict:
        """Extract the genre and section named in ``tokens``.

        Same contract as ``tokens_to_genre_section``; kept as a separate name
        for existing callers.
        """
        return self.tokens_to_genre_section(tokens)

    def make_classification_inference(self, temperature: float) -> dict:
        """Ask the model for the genre and section of the stored events.

        NOTE(review): ``temperature`` is accepted but not used; sampling runs
        at a fixed temperature of 0.1. Confirm whether that is intentional.

        Returns:
            A dict with keys ``"genre"`` and ``"section"``.

        Raises:
            AssertionError: If the decoded model output is empty.
        """
        genre_and_section_prompt = self.tokens_to_classification_prompt(self.__events_tokens)
        prompt = self.__tokenizer.encode(genre_and_section_prompt, add_special_tokens=True, return_tensors="pt")

        generated_section_genre_sequence = self.__model.generate(
            prompt,
            max_length=1024,
            do_sample=True,
            temperature=0.1,
            num_return_sequences=1,
        )

        section_genre_result = self.__tokenizer.decode(generated_section_genre_sequence[0], skip_special_tokens=True)
        assert len(section_genre_result) > 0, "Empty result"

        return self.tokens_to_genre_and_section_information(section_genre_result.split(" "))

    def make_section_events_inference(self, prompt: str, section: str, temperature: float, known_section: str) -> [[int, int]]:
        """Generate the events of ``section`` from ``prompt``.

        Args:
            prompt: Text prompt for the model.
            section: Section whose events are wanted.
            temperature: Sampling temperature for generation.
            known_section: Section used as a fallback source of events.

        Returns:
            Events parsed from the decoded output, with steps offset to
            ``section``.

        Raises:
            AssertionError: If the encoded prompt is longer than 1024 tokens
                or no events could be extracted.
            Exception: If the decoded output contains no token for the
                section being parsed.
        """
        tokenised_prompt = self.__tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt")
        assert len(tokenised_prompt[0]) <= 1024, "Prompt length exceeds maximum sequence length"

        generated_sequence = self.__model.generate(
            tokenised_prompt,
            max_length=1024,
            do_sample=True,
            temperature=temperature,
            num_return_sequences=1,
        )

        result = self.__tokenizer.decode(generated_sequence[0], skip_special_tokens=True)
        tokens = result.split(" ")

        events = self.tokens_to_section_events(tokens=tokens, section=section)
        # Fallback when nothing could be parsed for the requested section:
        # parse from the known section instead, shifted to the requested
        # section's offset.
        if len(events) == 0:
            events = self.tokens_to_section_events(
                tokens=tokens,
                section=known_section,
                step_offset=self.section_to_step_offset(section=section),
            )

        assert len(events) > 0, "Empty result"

        return events