Achillefs Sourlas committed
Commit fdeb8e7 · 1 Parent(s): 307a202

- Replaced the model with another fine-tuned GPT-2 with custom tokens.
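This commit does not include the training-side script that introduced the custom vocabulary, but the token inventory in tokenizer/added_tokens.json below, together with the vocab_size bump in model/config.json (50257 -> 50321, i.e. 64 new tokens), implies a standard extension roughly along these lines. A minimal sketch assuming the stock transformers API; this is not code from the commit:

    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # The 64 custom tokens listed in tokenizer/added_tokens.json.
    custom_tokens = (
        ["step:{0}".format(s) for s in range(32)]         # step:0 .. step:31
        + ["pitch:{0}".format(p) for p in range(36, 60)]  # pitch:36 .. pitch:59
        + ["genre:Trap", "genre:DHouse"]
        + ["section:a", "section:b", "section:c", "section:d"]
        + ["which_genre_section", "\n\n###\n\n"]
    )
    tokenizer.add_tokens(custom_tokens)
    model.resize_token_embeddings(len(tokenizer))  # 50257 -> 50321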
.gitignore ADDED
@@ -0,0 +1,2 @@
+venv/
+__pycache__/
app.py CHANGED
@@ -1,24 +1,31 @@
 import gradio as gr
 from beatgenerator import BeatGenerator
 from datetime import datetime
+from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
 STEP_COUNT = 32
-INSTRUMENT_COUNT = 3
+INSTRUMENT_COUNT = 9
 
-def on_submit(*grid_rows) -> str:
+model = GPT2LMHeadModel.from_pretrained("./model")
+tokenizer = GPT2Tokenizer.from_pretrained("./tokenizer")
+tokenizer.pad_token = tokenizer.eos_token
+beat_generator = BeatGenerator(model=model, tokenizer=tokenizer)
+
+def on_submit(*grid_rows) -> [str]:
     step_data_container = []
 
     for grid_row_id in range(INSTRUMENT_COUNT):
         grid_row_as_ints = list(map(lambda x: int(x) - 1, grid_rows[grid_row_id]))
         step_data_container.append(grid_row_as_ints)
 
-    beat_generator = BeatGenerator(data_container=step_data_container, temperature=grid_rows[3])
-    genre, beat_data = beat_generator.make_beat()
-
+    temperature: float = grid_rows[9]  # inputs 0-8 are the nine instrument rows
+    tempo: int = grid_rows[10]
+
     now = datetime.now()
     date_string = now.strftime("%Y-%m-%d_%H-%M")
+
+    genre, midi_data = beat_generator.generate_beat(user_prompt=step_data_container, temperature=temperature, tempo=tempo)
 
-    return """<div><h3>Genre: {0}</h3></div><br/><div><a href="data:audio/midi;base64,{1}" download="beat-{0}-{2}.mid">Download beat</a></div>""".format(genre, beat_data, date_string)
+    return ["""<div><h3>Genre: {0}</h3></div><br/><div><a href="data:audio/midi;base64,{1}" download="beat-{0}-{2}.mid">Download beat</a></div>""".format(genre, midi_data, date_string)]
 
 checkbox_rows = [
     ["{:02d}".format(col + 1) for col in range(STEP_COUNT)] for _ in range(INSTRUMENT_COUNT)
@@ -27,8 +34,15 @@ checkbox_rows = [
 inputs = [
     gr.inputs.CheckboxGroup(checkbox_rows[0], label="Kick"),
     gr.inputs.CheckboxGroup(checkbox_rows[1], label="Snare"),
-    gr.inputs.CheckboxGroup(checkbox_rows[2], label="Hat"),
-    gr.inputs.Slider(minimum=0.1, maximum=1.0, step=0.1, default=0.7, label="Temperature")
+    gr.inputs.CheckboxGroup(checkbox_rows[2], label="Clap"),
+    gr.inputs.CheckboxGroup(checkbox_rows[3], label="Hat"),
+    gr.inputs.CheckboxGroup(checkbox_rows[4], label="L tom"),
+    gr.inputs.CheckboxGroup(checkbox_rows[5], label="Open hat"),
+    gr.inputs.CheckboxGroup(checkbox_rows[6], label="M tom"),
+    gr.inputs.CheckboxGroup(checkbox_rows[7], label="Crash cymbal"),
+    gr.inputs.CheckboxGroup(checkbox_rows[8], label="Ride cymbal"),
+    gr.inputs.Slider(minimum=0.1, maximum=1.0, step=0.1, default=0.7, label="Temperature"),
+    gr.inputs.Slider(minimum=60, maximum=200, step=1, default=120, label="Tempo")
 ]
 
 iface = gr.Interface(
@@ -36,7 +50,7 @@ iface = gr.Interface(
     inputs=inputs,
     outputs=["html"],
     title="Simple (MIDI) Beat Generator",
-    description="A simple beat generator that creates 8-bar MIDI beats on every run, based on a 32-step (2-bar) prompt in the form of a step sequencer with three instruments: kick, snare and hi-hat. The generator uses a small fine-tuned GPT-2 model to recognise the genre (currently only Trap and Deep House) and generate the beat. The result is a mixture of the original prompt and the generated material. Higher temperature values may introduce more instruments to the generated beat."
+    description="A simple beat generator that creates 8-bar MIDI beats on every run, based on a 32-step (2-bar) prompt in the form of a step sequencer. The generator uses a small fine-tuned GPT-2 model to recognise the genre (currently only Trap and Deep House) and generate the beat."
 )
 
 iface.launch()
beatgenerator.py CHANGED
@@ -2,175 +2,54 @@ from midiutil import MIDIFile
 import base64
 from io import BytesIO
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
-import random
+from customtokenencoderdecoder import CustomTokenEncoderDecoder
 
 class BeatGenerator:
     STEP_SIZE = 0.25
+    STEPS_PER_SEQUENCE = 32
 
-    def __init__(self, data_container: [[int]], temperature: [float], model: str = "gpt2"):
-        self.__data_container = data_container
-        self.__temperature = temperature
-        self.__model = GPT2LMHeadModel.from_pretrained("./model")
-        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-        tokenizer.pad_token = tokenizer.eos_token
+    def __init__(self, model: GPT2LMHeadModel, tokenizer: GPT2Tokenizer):
+        self.__model = model
         self.__tokenizer = tokenizer
-        self.__model_type = model
-        self.__promptAsMusicEvents: [[int, int]] = []
+        self.__sections = ["a", "b", "c", "d"]
 
-    # Public methods
-    def make_beat(self) -> [str, str]:
-        genre_prompt = self.__make_genre_prompt()
-        genre = self.__make_genre_inference_gpt_2(prompt=genre_prompt, temperature=self.__temperature)
-        beat_prompt = self.__make_beat_prompt(genre=genre)
-        sequence = self.__make_inference_gpt_2(prompt=beat_prompt, temperature=self.__temperature)
-        music_events_container = self.__make_music_events(sequence=sequence)
-        composite_events = self.__make_composite_events(events=music_events_container)
-
-        midi_buffer = self.__make_midi_buffer(
-            data_container=composite_events,
-            verbose=False
-        )
-        midi_base64 = base64.b64encode(midi_buffer.read()).decode("utf-8")
-
-        return genre, midi_base64
-
-    # Private methods
-    def __make_genre_prompt(self) -> str:
-        pitches: [int] = [36, 38, 42]
-        prompt_prefix = "Seed "
-        prompt: str = ""
-        prompt_suffix = "Genre"
-
-        for instrument_id, instrument_steps in enumerate(self.__data_container):
-            for step in instrument_steps:
-                prompt += "({0},{1}) ".format(step, pitches[instrument_id])
-                self.__promptAsMusicEvents.append((step, pitches[instrument_id]))
-
-        return prompt_prefix + prompt + prompt_suffix
-
-    def __make_beat_prompt(self, genre: str) -> str:
-        pitches: [int] = [36, 38, 42]
-        prompt_prefix = "Seed "
-        prompt: str = ""
-        prompt_suffix = genre + " Beat"
-
-        for instrument_id, instrument_steps in enumerate(self.__data_container):
-            for step in instrument_steps:
-                prompt += "({0},{1}) ".format(step, pitches[instrument_id])
-                self.__promptAsMusicEvents.append((step, pitches[instrument_id]))
-
-        return prompt_prefix + prompt + prompt_suffix
-
-    def __make_music_events(self, sequence: [str]) -> [(int, int)]:
-        minimum_event_length = 5
-        result: [(int, int)] = []
-        sequence_list = sequence.split(" ")
-
-        if len(sequence_list) == 0:
-            return result
-
-        for i in range(len(sequence_list)):
-            if len(sequence_list[i]) >= minimum_event_length and sequence_list[i][0] == '(' and sequence_list[i][-1] == ')':
-                step_pitch = sequence_list[i][1:-1].split(",")
-
-                if len(step_pitch) == 2:
-                    isValid = True
-
-                    for item in step_pitch:
-                        if not item.isdigit():
-                            isValid = False
-                            break
-
-                    if isValid:
-                        result.append((int(step_pitch[0]), int(step_pitch[1])))
-
-        return result
-
-    def __make_inference(self, prompt: str, temperature: float) -> str:
-        if self.__model_type == "gpt2":
-            return self.__make_inference_gpt_2(prompt=prompt, temperature=temperature)
-        else:
-            raise Exception("Invalid model")
-
-    def __make_genre_inference_gpt_2(self, prompt: str, temperature: float) -> str:
-        print("Generating GPT-2 sequence... with temperature: {0}".format(temperature))
-        tokenised_prompt = self.__tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt")
-        generated_sequence = self.__model.generate(
-            tokenised_prompt,
-            max_length=1024,
-            do_sample=True,
-            temperature=0.5, #temperature,
-            num_return_sequences=1,
-        )
-
-        result = self.__tokenizer.decode(
-            generated_sequence[0], skip_special_tokens=True
-        )
-
-        result_as_list = result.split(" ")
-
-        return self.__find_next_element("Genre", result_as_list)
-
-    def __make_inference_gpt_2(self, prompt: str, temperature: float) -> str:
-        print("Generating GPT-2 sequence... with temperature: {0}".format(temperature))
-        tokenised_prompt = self.__tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt")
-        generated_sequence = self.__model.generate(
-            tokenised_prompt,
-            max_length=1024,
-            do_sample=True,
-            temperature=temperature,
-            num_return_sequences=1,
-        )
-
-        result = self.__tokenizer.decode(
-            generated_sequence[0], skip_special_tokens=True
-        )
-
-        return result
-
-    def __matches_probability(self, p: float) -> bool:
-        if p < 0 or p > 1:
-            raise ValueError("Probability should be a number between 0 and 1.")
-        random_number = random.random()
-
-        return random_number <= p
-
-    def __make_composite_events(self, events: [(int, int)]) -> [(int, int)]:
-        result: [(int, int)] = []
-        kick_probability = 0.95
-        snare_probability = 0.85
-        hihat_probability = 0.7
-        steps_per_time = 32
-        user_events_for_the_whole_beat: [(int, int)] = []
-
-        for user_event in self.__promptAsMusicEvents:
-            for times in range(4):
-                user_events_for_the_whole_beat.append((user_event[0] + times * steps_per_time, user_event[1]))
-
-        unique_user_events = list(filter(lambda x: x not in events, user_events_for_the_whole_beat))
-
-        for unique_user_event in unique_user_events:
-            probability = 0
-            if unique_user_event[1] == 36:
-                probability = kick_probability
-            elif unique_user_event[1] == 38:
-                probability = snare_probability
-            elif unique_user_event[1] == 42:
-                probability = hihat_probability
-
-            if self.__matches_probability(p=probability) == True:
-                result.append(unique_user_event)
-
-        for event in events:
-            result.append(event)
-
-        return result
+    def generate_beat(self, user_prompt: [[int]], temperature: float, tempo: float) -> [str, str]:
+        # pitches = [36, 38, 42]
+        pitches = [36, 38, 39, 42, 45, 46, 47, 49, 51]
+        assert len(user_prompt) == len(pitches), "User prompt length must be equal to the number of pitches"
+
+        user_events: [[int, int]] = []
+        for pitch_id, pitch in enumerate(pitches):
+            for step in user_prompt[pitch_id]:
+                user_events.append((step, pitch))
+
+        custom_token_encoder_decoder = CustomTokenEncoderDecoder(
+            events=user_events,
+            sections=self.__sections,
+            steps_per_section=self.STEPS_PER_SEQUENCE,
+            model=self.__model,
+            tokenizer=self.__tokenizer,
+        )
+
+        result = custom_token_encoder_decoder.generate_events(temperature=temperature)
+
+        genre = result["genre"]
+        events = result["events"]
+
+        midi_buffer = self.__make_midi_buffer(
+            data_container=events,
+            tempo=tempo,
+            verbose=False
+        )
+        midi_base64 = base64.b64encode(midi_buffer.read()).decode("utf-8")
+
+        return genre, midi_base64
 
-    def __make_midi_buffer(self, data_container: [(int, int)], verbose: bool = False) -> BytesIO:
+    def __make_midi_buffer(self, data_container: [(int, int)], tempo: int, verbose: bool = False) -> BytesIO:
         track_count = 1
         out_midi_file = MIDIFile(1)
-        out_midi_file.addTempo(0, 0, 120)
+        out_midi_file.addTempo(0, 0, tempo)
 
         for data in data_container:
             step = data[0]
             pitch = data[1]
@@ -188,18 +67,15 @@ class BeatGenerator:
             channel=9,
             pitch=pitch,
            time=start_time,
-            duration=0.25,
+            duration=self.STEP_SIZE,
             volume=volume
         )
 
         buffer = BytesIO()
         out_midi_file.writeFile(buffer)
         buffer.seek(0)
 
-        return buffer
-
-    def __find_next_element(self, target, string_list) -> str:
-        for i, string in enumerate(string_list[:-1]):
-            if string == target:
-                return string_list[i + 1]
-        return ""  # If the target string is not found or is the last element in the list
+        with open("out.mid", "wb") as output_file:
+            out_midi_file.writeFile(output_file)
+
+        return buffer
customtokenencoderdecoder.py ADDED
@@ -0,0 +1,203 @@
+from transformers import GPT2Tokenizer, GPT2LMHeadModel
+
+class CustomTokenEncoderDecoder:
+    CUSTOM_CLASSIFICATION_TOKEN = "which_genre_section"
+
+    def __init__(self, events: [[int, int]], sections: [str], steps_per_section: int, model: GPT2LMHeadModel, tokenizer: GPT2Tokenizer):
+        self.__model = model
+        self.__tokenizer = tokenizer
+        self.__events = events
+        self.__steps_per_section = steps_per_section
+        self.__sections = sections
+        self.__events_tokens = self.events_to_tokens(events)
+
+    def is_step_token(self, token: str) -> bool:
+        return token.startswith("step:")
+
+    def is_pitch_token(self, token: str) -> bool:
+        return token.startswith("pitch:")
+
+    def is_genre_token(self, token: str) -> bool:
+        return token.startswith("genre:")
+
+    def is_section_token(self, token: str) -> bool:
+        return token.startswith("section:")
+
+    def token_to_pitch(self, token: str) -> int:
+        return int(token.split(":")[1])
+
+    def token_to_step(self, token: str) -> int:
+        return int(token.split(":")[1])
+
+    def token_to_section(self, token: str) -> str:
+        return token.split(":")[1]
+
+    def token_to_genre(self, token: str) -> str:
+        return token.split(":")[1]
+
+    def pitch_to_token(self, pitch: int) -> str:
+        return "pitch:{0}".format(pitch)
+
+    def step_to_token(self, step: int) -> str:
+        return "step:{0}".format(step)
+
+    def section_to_token(self, section: str) -> str:
+        return "section:{0}".format(section)
+
+    def events_to_tokens(self, events: [[int, int]]) -> [str]:
+        result: [str] = []
+
+        for step_id in range(self.__steps_per_section):
+            step_data = list(filter(lambda x: x[0] == step_id, events))
+
+            if len(step_data) > 0:
+                result.append(self.step_to_token(step_id))
+                step_tokens = list(map(lambda x: self.pitch_to_token(x[1]), step_data))
+                if len(step_tokens) > 0:
+                    result += step_tokens
+
+        return result
+
+    def tokens_to_classification_prompt(self, tokens: [str]) -> str:
+        return " ".join(tokens + [self.CUSTOM_CLASSIFICATION_TOKEN])
+
+    def tokens_to_section_prompt(self, tokens: [str], section: str, prompted_section: str) -> str:
+        return " ".join([self.section_to_token(section)] + tokens + [self.section_to_token(prompted_section)])
+
+    def tokens_to_genre_section(self, tokens: [str]) -> dict:
+        genre: str = ""
+        section: str = ""
+
+        for token in tokens:
+            if self.is_genre_token(token):
+                genre = self.token_to_genre(token)
+            elif self.is_section_token(token):
+                section = self.token_to_section(token)
+
+        return {"genre": genre, "section": section}
+
+    def section_to_step_offset(self, section: str) -> int:
+        if section == "a":
+            return 0
+        elif section == "b":
+            return self.__steps_per_section
+        elif section == "c":
+            return 2 * self.__steps_per_section
+        elif section == "d":
+            return 3 * self.__steps_per_section
+        else:
+            raise Exception("Invalid section: {0}".format(section))
+
+    def tokens_to_section_events(self, tokens: [str], section: str, step_offset: int = None) -> [[int, int]]:
+        for (token_id, token) in enumerate(tokens):
+            if self.is_section_token(token):
+                if self.token_to_section(token) == section:
+                    offset: int = self.section_to_step_offset(section)
+                    if step_offset is not None:
+                        offset = step_offset
+                    return self.tokens_to_events(tokens=tokens[token_id:], step_offset=offset)
+
+        raise Exception("Section {0} not found in tokens".format(section))
+
+    def tokens_to_events(self, tokens: [str], step_offset: int) -> [[int, int]]:
+        result: [[int, int]] = []
+
+        for (token_id, token) in enumerate(tokens):
+            if self.is_step_token(token):
+                step = self.token_to_step(token) + step_offset
+                next_token_id = token_id + 1
+
+                while next_token_id < len(tokens) and self.is_pitch_token(tokens[next_token_id]):
+                    pitch = self.token_to_pitch(tokens[next_token_id])
+                    result.append((step, pitch))
+                    next_token_id += 1
+
+        return result
+
+    def convert_events_to_section_events(self, events: [[int, int]], section: str) -> [[int, int]]:
+        offset = self.section_to_step_offset(section)
+        return list(map(lambda x: (x[0] + offset, x[1]), events))
+
+    def generate_events(self, temperature: float) -> dict:
+        genre_section_data = self.make_classification_inference(temperature=temperature)
+        genre = genre_section_data["genre"]
+        section = genre_section_data["section"]
+        print("Classification results")
+        print("======================")
+        print("Found genre: {0}".format(genre))
+        print("Found section: {0}".format(section))
+        print("======================")
+
+        # Validate the classified section before using it as an offset.
+        if section not in self.__sections:
+            raise Exception("Section {0} not found in sections".format(section))
+
+        all_events: [[int, int]] = []
+        all_events += list(map(lambda x: (x[0] + self.section_to_step_offset(section=section), x[1]), self.__events))
+
+        other_sections = list(filter(lambda x: x != section, self.__sections))
+        for other_section in other_sections:
+            prompt = self.tokens_to_section_prompt(tokens=self.__events_tokens, section=section, prompted_section=other_section)
+            events = self.make_section_events_inference(prompt=prompt, temperature=temperature, section=other_section, known_section=section)
+            all_events += events
+
+        return {
+            "events": all_events,
+            "genre": genre
+        }
+
+    def tokens_to_genre_and_section_information(self, tokens: [str]) -> dict:
+        genre: str = ""
+        section: str = ""
+
+        for token in tokens:
+            if self.is_genre_token(token):
+                genre = self.token_to_genre(token)
+            elif self.is_section_token(token):
+                section = self.token_to_section(token)
+
+        return {"genre": genre, "section": section}
+
+    def make_classification_inference(self, temperature: float) -> dict:
+        genre_and_section_prompt = self.tokens_to_classification_prompt(self.__events_tokens)
+        prompt = self.__tokenizer.encode(genre_and_section_prompt, add_special_tokens=True, return_tensors="pt")
+
+        generated_section_genre_sequence = self.__model.generate(
+            prompt,
+            max_length=1024,
+            do_sample=True,
+            temperature=0.1,
+            num_return_sequences=1,
+        )
+
+        section_genre_result = self.__tokenizer.decode(generated_section_genre_sequence[0], skip_special_tokens=True)
+        assert len(section_genre_result) > 0, "Empty result"
+
+        genre_section_data = self.tokens_to_genre_and_section_information(section_genre_result.split(" "))
+        return genre_section_data
+
+    def make_section_events_inference(self, prompt: str, section: str, temperature: float, known_section: str) -> [[int, int]]:
+        tokenised_prompt = self.__tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt")
+        assert len(tokenised_prompt[0]) <= 1024, "Prompt length exceeds maximum sequence length"
+
+        generated_sequence = self.__model.generate(
+            tokenised_prompt,
+            max_length=1024,
+            do_sample=True,
+            temperature=temperature,
+            num_return_sequences=1,
+        )
+
+        result = self.__tokenizer.decode(
+            generated_sequence[0], skip_special_tokens=True
+        )
+
+        events = self.tokens_to_section_events(tokens=result.split(" "), section=section)
+        # Fallback when inference fails (sometimes the model generates a sequence that doesn't contain the section)
+        if len(events) == 0:
+            events = self.tokens_to_section_events(tokens=result.split(" "), section=known_section, step_offset=self.section_to_step_offset(section=section))
+
+        assert len(events) > 0, "Empty result"
+
+        return events
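For reference, a short sketch of the token round trip this class performs; the drum pattern is invented for illustration and the paths assume this repo's layout:

    from transformers import GPT2LMHeadModel, GPT2Tokenizer
    from customtokenencoderdecoder import CustomTokenEncoderDecoder

    model = GPT2LMHeadModel.from_pretrained("./model")
    tokenizer = GPT2Tokenizer.from_pretrained("./tokenizer")

    # A four-on-the-floor kick (pitch 36) on steps 0, 8, 16 and 24.
    events = [(0, 36), (8, 36), (16, 36), (24, 36)]
    codec = CustomTokenEncoderDecoder(
        events=events, sections=["a", "b", "c", "d"],
        steps_per_section=32, model=model, tokenizer=tokenizer,
    )

    tokens = codec.events_to_tokens(events)
    # ['step:0', 'pitch:36', 'step:8', 'pitch:36', 'step:16', 'pitch:36', 'step:24', 'pitch:36']

    codec.tokens_to_classification_prompt(tokens)
    # 'step:0 pitch:36 ... step:24 pitch:36 which_genre_section'

    codec.tokens_to_section_prompt(tokens, section="a", prompted_section="b")
    # 'section:a step:0 pitch:36 ... step:24 pitch:36 section:b'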
model/config.json CHANGED
@@ -33,7 +33,7 @@
     }
   },
   "torch_dtype": "float32",
-  "transformers_version": "4.27.4",
+  "transformers_version": "4.28.1",
   "use_cache": true,
-  "vocab_size": 50257
+  "vocab_size": 50321
 }
model/generation_config.json CHANGED
@@ -2,5 +2,5 @@
   "_from_model_config": true,
   "bos_token_id": 50256,
   "eos_token_id": 50256,
-  "transformers_version": "4.27.4"
+  "transformers_version": "4.28.1"
 }
model/pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9fa41174f8226daad7dca89c28f3b94ffc20b668dd85b4450d65f43bb07a23a9
-size 510398013
+oid sha256:f1e51cee355f39d25c000d17d50c4313dd1787f086ce23191b8a495f9c33a82b
+size 510594621
model/training_args.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4126f3661406a7f858e5e8758fd29a8517d37c711dffa94648b6153e41ed1659
-size 3643
+oid sha256:94b82dda5c87ea468fb088c62a5c04ce9158aed8f35b34ed6e7ab193f4cb4c8f
+size 3707
requirements.txt CHANGED
@@ -1,3 +1,4 @@
+gradio
 MIDIUtil
 transformers
 torch
tokenizer/added_tokens.json ADDED
@@ -0,0 +1,66 @@
+{
+  "\n\n###\n\n": 50320,
+  "genre:DHouse": 50314,
+  "genre:Trap": 50313,
+  "pitch:36": 50289,
+  "pitch:37": 50290,
+  "pitch:38": 50291,
+  "pitch:39": 50292,
+  "pitch:40": 50293,
+  "pitch:41": 50294,
+  "pitch:42": 50295,
+  "pitch:43": 50296,
+  "pitch:44": 50297,
+  "pitch:45": 50298,
+  "pitch:46": 50299,
+  "pitch:47": 50300,
+  "pitch:48": 50301,
+  "pitch:49": 50302,
+  "pitch:50": 50303,
+  "pitch:51": 50304,
+  "pitch:52": 50305,
+  "pitch:53": 50306,
+  "pitch:54": 50307,
+  "pitch:55": 50308,
+  "pitch:56": 50309,
+  "pitch:57": 50310,
+  "pitch:58": 50311,
+  "pitch:59": 50312,
+  "section:a": 50315,
+  "section:b": 50316,
+  "section:c": 50317,
+  "section:d": 50318,
+  "step:0": 50257,
+  "step:1": 50258,
+  "step:10": 50267,
+  "step:11": 50268,
+  "step:12": 50269,
+  "step:13": 50270,
+  "step:14": 50271,
+  "step:15": 50272,
+  "step:16": 50273,
+  "step:17": 50274,
+  "step:18": 50275,
+  "step:19": 50276,
+  "step:2": 50259,
+  "step:20": 50277,
+  "step:21": 50278,
+  "step:22": 50279,
+  "step:23": 50280,
+  "step:24": 50281,
+  "step:25": 50282,
+  "step:26": 50283,
+  "step:27": 50284,
+  "step:28": 50285,
+  "step:29": 50286,
+  "step:3": 50260,
+  "step:30": 50287,
+  "step:31": 50288,
+  "step:4": 50261,
+  "step:5": 50262,
+  "step:6": 50263,
+  "step:7": 50264,
+  "step:8": 50265,
+  "step:9": 50266,
+  "which_genre_section": 50319
+}
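The 64 added tokens occupy the contiguous ID range 50257-50320, which is exactly what raises vocab_size from 50257 to 50321 in model/config.json above. A quick, illustrative consistency check:

    import json

    with open("tokenizer/added_tokens.json") as f:
        added = json.load(f)

    assert len(added) == 64
    assert sorted(added.values()) == list(range(50257, 50321))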
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
+{
+  "bos_token": {
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eos_token": {
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": "<|endoftext|>",
+  "unk_token": {
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  }
+}
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,33 @@
+{
+  "add_bos_token": false,
+  "add_prefix_space": false,
+  "bos_token": {
+    "__type": "AddedToken",
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "clean_up_tokenization_spaces": true,
+  "eos_token": {
+    "__type": "AddedToken",
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "errors": "replace",
+  "model_max_length": 1024,
+  "pad_token": null,
+  "tokenizer_class": "GPT2Tokenizer",
+  "unk_token": {
+    "__type": "AddedToken",
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  }
+}
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff