badd9yang committed
Commit c045d4a · verified · 1 Parent(s): 3e98f0c

Update utils.py

Files changed (1)
  1. utils.py +230 -0
utils.py CHANGED
@@ -0,0 +1,230 @@
+
+
+from typing import List, Tuple
+
+import gradio as gr
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import note_seq
+from matplotlib.figure import Figure
+from numpy import ndarray
+import torch
+
+from constants import GM_INSTRUMENTS, SAMPLE_RATE
+from string_to_notes import token_sequence_to_note_sequence
+from model import get_model_and_tokenizer
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Load the tokenizer and the model
+model, tokenizer = get_model_and_tokenizer()
+
+
+def create_seed_string(genre: str = "OTHER") -> str:
+    """
+    Creates a seed string for generating a new piece.
+    Args:
+        genre (str, optional): The genre of the piece. Defaults to "OTHER".
+    Returns:
+        str: The seed string.
+    """
+    if genre == "RANDOM":
+        seed_string = "PIECE_START"
+    else:
+        seed_string = f"PIECE_START GENRE={genre} TRACK_START"
+    return seed_string
+
+
+def get_instruments(text_sequence: str) -> List[str]:
+    """
+    Extracts the list of instruments from a text sequence.
+    Args:
+        text_sequence (str): The text sequence.
+    Returns:
+        List[str]: The list of instruments.
+    """
+    instruments = []
+    parts = text_sequence.split()
+    for part in parts:
+        if part.startswith("INST="):
+            if part[5:] == "DRUMS":
+                instruments.append("Drums")
+            else:
+                index = int(part[5:])
+                instruments.append(GM_INSTRUMENTS[index])
+    return instruments
+
+
+def generate_new_instrument(seed: str, temp: float = 0.75) -> str:
+    """
+    Generates a new instrument sequence from a given seed and temperature.
+    Args:
+        seed (str): The seed string for the generation.
+        temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.
+    Returns:
+        str: The generated instrument sequence.
+    """
+    seed_length = len(tokenizer.encode(seed))
+
+    while True:
+        # Encode the conditioning tokens.
+        input_ids = tokenizer.encode(seed, return_tensors="pt")
+
+        # Move the input_ids tensor to the same device as the model
+        input_ids = input_ids.to(model.device)
+
+        # Generate more tokens.
+        eos_token_id = tokenizer.encode("TRACK_END")[0]
+        generated_ids = model.generate(
+            input_ids,
+            max_new_tokens=2048,
+            do_sample=True,
+            temperature=temp,
+            eos_token_id=eos_token_id,
+        )
+        generated_sequence = tokenizer.decode(generated_ids[0])
+
+        # Check if the generated sequence contains "NOTE_ON" beyond the seed
+        new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
+        if "NOTE_ON" in new_generated_sequence:
+            return generated_sequence
+
+
+def get_outputs_from_string(
+    generated_sequence: str, qpm: int = 120
+) -> Tuple[ndarray, str, Figure, str, str]:
+    """
+    Converts a generated sequence into various output formats, including audio, MIDI, and a plot.
+    Args:
+        generated_sequence (str): The generated sequence of tokens.
+        qpm (int, optional): The quarter notes per minute. Defaults to 120.
+    Returns:
+        Tuple[ndarray, str, Figure, str, str]: The audio waveform, MIDI file name, plot figure,
+        instruments string, and number of tokens string.
+    """
+    instruments = get_instruments(generated_sequence)
+    instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
+    note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
+
+    synth = note_seq.fluidsynth
+    array_of_floats = synth(note_sequence, sample_rate=SAMPLE_RATE)
+    int16_data = note_seq.audio_io.float_samples_to_int16(array_of_floats)
+    fig = note_seq.plot_sequence(note_sequence, show_figure=False)
+    num_tokens = str(len(generated_sequence.split()))
+    audio = gr.make_waveform((SAMPLE_RATE, int16_data))
+    note_seq.note_sequence_to_midi_file(note_sequence, "midi_output.mid")
+    return audio, "midi_output.mid", fig, instruments_str, num_tokens
+
+
+def remove_last_instrument(
+    text_sequence: str, qpm: int = 120
+) -> Tuple[ndarray, str, Figure, str, str, str]:
+    """
+    Removes the last instrument from a song string and returns the various output formats.
+    Args:
+        text_sequence (str): The song string.
+        qpm (int, optional): The quarter notes per minute. Defaults to 120.
+    Returns:
+        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
+        instruments string, new song string, and number of tokens string.
+    """
+    # We split the song into tracks by splitting on 'TRACK_START'
+    tracks = text_sequence.split("TRACK_START")
+    # We keep all tracks except the last one
+    modified_tracks = tracks[:-1]
+    # We join the tracks back together, adding back the 'TRACK_START' that was removed by split
+    new_song = "TRACK_START".join(modified_tracks)
+
+    if len(tracks) == 2:
+        # There is only one instrument, so start from scratch
+        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
+            text_sequence=new_song, qpm=qpm
+        )
+    elif len(tracks) == 1:
+        # No instrument, so start from an empty sequence
+        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
+            text_sequence="", qpm=qpm
+        )
+    else:
+        audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
+            new_song, qpm
+        )
+
+    return audio, midi_file, fig, instruments_str, new_song, num_tokens
+
+
+def regenerate_last_instrument(
+    text_sequence: str, qpm: int = 120
+) -> Tuple[ndarray, str, Figure, str, str, str]:
+    """
+    Regenerates the last instrument in a song string and returns the various output formats.
+    Args:
+        text_sequence (str): The song string.
+        qpm (int, optional): The quarter notes per minute. Defaults to 120.
+    Returns:
+        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
+        instruments string, new song string, and number of tokens string.
+    """
+    last_inst_index = text_sequence.rfind("INST=")
+    if last_inst_index == -1:
+        # No instrument, so start from an empty sequence
+        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
+            text_sequence="", qpm=qpm
+        )
+    else:
+        # Keep the song up to and including the last INST= token and continue generation from there
+        next_space_index = text_sequence.find(" ", last_inst_index)
+        new_seed = text_sequence[:next_space_index]
+        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
+            text_sequence=new_seed, qpm=qpm
+        )
+    return audio, midi_file, fig, instruments_str, new_song, num_tokens
+
+
+def change_tempo(
+    text_sequence: str, qpm: int
+) -> Tuple[ndarray, str, Figure, str, str, str]:
+    """
+    Changes the tempo of a song string and returns the various output formats.
+    Args:
+        text_sequence (str): The song string.
+        qpm (int): The new quarter notes per minute.
+    Returns:
+        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
+        instruments string, text sequence, and number of tokens string.
+    """
+    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
+        text_sequence, qpm=qpm
+    )
+    return audio, midi_file, fig, instruments_str, text_sequence, num_tokens
+
+
+def generate_song(
+    genre: str = "OTHER",
+    temp: float = 0.75,
+    text_sequence: str = "",
+    qpm: int = 120,
+) -> Tuple[ndarray, str, Figure, str, str, str]:
+    """
+    Generates a song given a genre, temperature, initial text sequence, and tempo.
+    Uses the module-level model and tokenizer loaded via get_model_and_tokenizer().
+    Args:
+        genre (str, optional): The genre of the song. Defaults to "OTHER".
+        temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.
+        text_sequence (str, optional): The initial text sequence for the song. Defaults to "".
+        qpm (int, optional): The quarter notes per minute. Defaults to 120.
+    Returns:
+        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
+        instruments string, generated song string, and number of tokens string.
+    """
+    if text_sequence == "":
+        seed_string = create_seed_string(genre)
+    else:
+        seed_string = text_sequence
+
+    generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)
+    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
+        generated_sequence, qpm
+    )
+    return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens
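
Usage sketch (not part of this commit): a minimal example of how the helpers above might be driven from Python, assuming the sibling modules in this repository (constants.py, model.py, string_to_notes.py) are importable; the genre and qpm values below are illustrative.

# Generate a first track for a new piece and render audio, MIDI, and a plot.
from utils import generate_song, change_tempo

audio, midi_file, fig, instruments, song, num_tokens = generate_song(
    genre="OTHER", temp=0.75, text_sequence="", qpm=120
)
print(instruments)  # e.g. "- Drums"
print(num_tokens)   # number of tokens in the generated sequence

# Continue the same piece with an additional instrument, then re-render it at 90 qpm.
audio, midi_file, fig, instruments, song, num_tokens = generate_song(
    text_sequence=song, qpm=120
)
audio, midi_file, fig, instruments, song, num_tokens = change_tempo(song, qpm=90)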