Update app.py
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ os.environ["KERAS_BACKEND"] ="tensorflow"
|
|
| 4 |
import random
|
| 5 |
import keras
|
| 6 |
import gradio as gr
|
|
|
|
| 7 |
from src.inference import generate_chorale, draw_random_sample
|
| 8 |
from src.dataset import NoteEncoder
|
| 9 |
from src.metrics import Preplexity
|
|
@@ -21,6 +22,7 @@ wav_path = os.path.join(AUDIO_SAMPLES_PATH, "sample.wav")
|
|
| 21 |
|
| 22 |
|
| 23 |
### DOWNLOAD SF2 MUSIC FONT ###
|
|
|
|
| 24 |
sf2_download_path = keras.utils.get_file(
|
| 25 |
"FluidR3_GM.zip",
|
| 26 |
"https://keymusician01.s3.amazonaws.com/FluidR3_GM.zip",
|
|
@@ -29,15 +31,23 @@ sf2_download_path = keras.utils.get_file(
|
|
| 29 |
cache_subdir= ""
|
| 30 |
)
|
| 31 |
SF2_PATH = os.path.join(sf2_download_path, "FluidR3_GM.sf2")
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
### LOAD MODEL & ENCODERS ###
|
|
|
|
|
|
|
| 35 |
model = keras.models.load_model(os.path.join(MODEL_PATH, "bach_model.keras"),
|
| 36 |
custom_objects={"Preplexity": Preplexity})
|
|
|
|
|
|
|
|
|
|
| 37 |
note2id, id2note, vocab = NoteEncoder(vocab_path=ARTIFACTS_PATH, samples_path=None)
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
### GRADIO ASSETS ###
|
|
|
|
| 41 |
css = load_css()
|
| 42 |
english_summary = load_markdown("english_summary")
|
| 43 |
persian_summary = load_markdown("persian_summary")
|
|
@@ -45,6 +55,7 @@ english_help = load_markdown("english_help")
|
|
| 45 |
persian_help = load_markdown("persian_help")
|
| 46 |
english_title = "# BachNet: AI-Generated Bach Music"
|
| 47 |
persian_title = "# باخنت: خلق موسیقی مشابه باخ با هوش مصنوعی"
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
### GENERATION FUNCTIONS ###
|
|
@@ -53,7 +64,9 @@ def pick_random_seed():
|
|
| 53 |
|
| 54 |
def generate_fn(seed_path, seed_len, gen_len, temp):
|
| 55 |
sample_rows = slice(0, seed_len)
|
| 56 |
-
|
|
|
|
|
|
|
| 57 |
generate_chorale(
|
| 58 |
model=model,
|
| 59 |
sample_seed_path=seed_path,
|
|
@@ -64,8 +77,14 @@ def generate_fn(seed_path, seed_len, gen_len, temp):
|
|
| 64 |
temperature=temp,
|
| 65 |
sample_seed_rows=sample_rows
|
| 66 |
)
|
| 67 |
-
|
|
|
|
|
|
|
| 68 |
midi_to_wave(midi_file_path=midi_path, SF2_PATH=SF2_PATH, wave_path=wav_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
return wav_path
|
| 70 |
|
| 71 |
def set_english():
|
|
@@ -121,4 +140,4 @@ with gr.Blocks(css=css, title="BachNet") as demo:
|
|
| 121 |
|
| 122 |
### LAUNCH APP ###
|
| 123 |
if __name__ == "__main__":
|
| 124 |
-
demo.launch()
|
|
|
|
| 4 |
import random
|
| 5 |
import keras
|
| 6 |
import gradio as gr
|
| 7 |
+
import time # <-- added
|
| 8 |
from src.inference import generate_chorale, draw_random_sample
|
| 9 |
from src.dataset import NoteEncoder
|
| 10 |
from src.metrics import Preplexity
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
### DOWNLOAD SF2 MUSIC FONT ###
|
| 25 |
+
t0 = time.time()
|
| 26 |
sf2_download_path = keras.utils.get_file(
|
| 27 |
"FluidR3_GM.zip",
|
| 28 |
"https://keymusician01.s3.amazonaws.com/FluidR3_GM.zip",
|
|
|
|
| 31 |
cache_subdir= ""
|
| 32 |
)
|
| 33 |
SF2_PATH = os.path.join(sf2_download_path, "FluidR3_GM.sf2")
|
| 34 |
+
print(f"SF2 font download/extract took {time.time()-t0:.2f}s")
|
| 35 |
|
| 36 |
|
| 37 |
### LOAD MODEL & ENCODERS ###
|
| 38 |
+
print("loading model...")
|
| 39 |
+
t0 = time.time()
|
| 40 |
model = keras.models.load_model(os.path.join(MODEL_PATH, "bach_model.keras"),
|
| 41 |
custom_objects={"Preplexity": Preplexity})
|
| 42 |
+
print(f"model loaded in {time.time()-t0:.2f}s")
|
| 43 |
+
|
| 44 |
+
t1 = time.time()
|
| 45 |
note2id, id2note, vocab = NoteEncoder(vocab_path=ARTIFACTS_PATH, samples_path=None)
|
| 46 |
+
print(f"NoteEncoder init took {time.time()-t1:.2f}s")
|
| 47 |
|
| 48 |
|
| 49 |
### GRADIO ASSETS ###
|
| 50 |
+
t2 = time.time()
|
| 51 |
css = load_css()
|
| 52 |
english_summary = load_markdown("english_summary")
|
| 53 |
persian_summary = load_markdown("persian_summary")
|
|
|
|
| 55 |
persian_help = load_markdown("persian_help")
|
| 56 |
english_title = "# BachNet: AI-Generated Bach Music"
|
| 57 |
persian_title = "# باخنت: خلق موسیقی مشابه باخ با هوش مصنوعی"
|
| 58 |
+
print(f"Loaded CSS/markdown assets in {time.time()-t2:.2f}s")
|
| 59 |
|
| 60 |
|
| 61 |
### GENERATION FUNCTIONS ###
|
|
|
|
| 64 |
|
| 65 |
def generate_fn(seed_path, seed_len, gen_len, temp):
|
| 66 |
sample_rows = slice(0, seed_len)
|
| 67 |
+
|
| 68 |
+
print("=== Generation started ===")
|
| 69 |
+
t0 = time.time()
|
| 70 |
generate_chorale(
|
| 71 |
model=model,
|
| 72 |
sample_seed_path=seed_path,
|
|
|
|
| 77 |
temperature=temp,
|
| 78 |
sample_seed_rows=sample_rows
|
| 79 |
)
|
| 80 |
+
t1 = time.time()
|
| 81 |
+
print(f"generate_chorale took {t1 - t0:.2f}s")
|
| 82 |
+
|
| 83 |
midi_to_wave(midi_file_path=midi_path, SF2_PATH=SF2_PATH, wave_path=wav_path)
|
| 84 |
+
t2 = time.time()
|
| 85 |
+
print(f"midi_to_wave took {t2 - t1:.2f}s")
|
| 86 |
+
print(f"Total generate_fn time {t2 - t0:.2f}s")
|
| 87 |
+
|
| 88 |
return wav_path
|
| 89 |
|
| 90 |
def set_english():
|
|
|
|
| 140 |
|
| 141 |
### LAUNCH APP ###
|
| 142 |
if __name__ == "__main__":
|
| 143 |
+
demo.launch()
|