Update utils.py
Browse files
utils.py
CHANGED
|
@@ -18,7 +18,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
| 18 |
model, tokenizer = get_model_and_tokenizer()
|
| 19 |
|
| 20 |
|
| 21 |
-
def create_seed_string(genre: str = "OTHER") -> str:
|
| 22 |
"""
|
| 23 |
Creates a seed string for generating a new piece.
|
| 24 |
|
|
@@ -28,10 +28,14 @@ def create_seed_string(genre: str = "OTHER") -> str:
|
|
| 28 |
Returns:
|
| 29 |
str: The seed string.
|
| 30 |
"""
|
| 31 |
-
if genre == "RANDOM":
|
| 32 |
seed_string = "PIECE_START"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
else:
|
| 34 |
-
seed_string = f"PIECE_START GENRE={genre} TRACK_START"
|
| 35 |
return seed_string
|
| 36 |
|
| 37 |
|
|
@@ -235,11 +239,9 @@ def generate_song(
|
|
| 235 |
instruments string, generated song string, and number of tokens string.
|
| 236 |
"""
|
| 237 |
if text_sequence == "":
|
| 238 |
-
|
| 239 |
-
seed_string = create_seed_string(seed)
|
| 240 |
else:
|
| 241 |
seed_string = text_sequence
|
| 242 |
-
print(seed_string)
|
| 243 |
|
| 244 |
generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)
|
| 245 |
audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
|
|
|
|
| 18 |
model, tokenizer = get_model_and_tokenizer()
|
| 19 |
|
| 20 |
|
| 21 |
+
def create_seed_string(genre: str = "OTHER", artist: str = "OTHER") -> str:
|
| 22 |
"""
|
| 23 |
Creates a seed string for generating a new piece.
|
| 24 |
|
|
|
|
| 28 |
Returns:
|
| 29 |
str: The seed string.
|
| 30 |
"""
|
| 31 |
+
if genre == "RANDOM" and artist == "RANDOM":
|
| 32 |
seed_string = "PIECE_START"
|
| 33 |
+
elif genre == "RANDOM" and artist != "RANDOM":
|
| 34 |
+
seed_string = f"PIECE_START GENRE=RANDOM GENRE={artist} TRACK_START"
|
| 35 |
+
elif genre != "RANDOM" and artist == "RANDOM":
|
| 36 |
+
seed_string = f"PIECE_START GENRE={genre} GENRE=RANDOM TRACK_START"
|
| 37 |
else:
|
| 38 |
+
seed_string = f"PIECE_START GENRE={genre} GENRE={artist} TRACK_START"
|
| 39 |
return seed_string
|
| 40 |
|
| 41 |
|
|
|
|
| 239 |
instruments string, generated song string, and number of tokens string.
|
| 240 |
"""
|
| 241 |
if text_sequence == "":
|
| 242 |
+
seed_string = create_seed_string(genre, artist)
|
|
|
|
| 243 |
else:
|
| 244 |
seed_string = text_sequence
|
|
|
|
| 245 |
|
| 246 |
generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)
|
| 247 |
audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
|