Update app.py
Browse files
app.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
| 1 |
### IMPORTS ###
|
| 2 |
import os
|
| 3 |
os.environ["KERAS_BACKEND"] ="tensorflow"
|
|
|
|
| 4 |
import random
|
| 5 |
import keras
|
| 6 |
import gradio as gr
|
| 7 |
-
import time
|
| 8 |
from src.inference import generate_chorale, draw_random_sample
|
| 9 |
from src.dataset import NoteEncoder
|
| 10 |
from src.metrics import Preplexity
|
|
@@ -60,7 +61,8 @@ print(f"Loaded CSS/markdown assets in {time.time()-t2:.2f}s")
|
|
| 60 |
|
| 61 |
### GENERATION FUNCTIONS ###
|
| 62 |
def pick_random_seed():
|
| 63 |
-
|
|
|
|
| 64 |
|
| 65 |
def generate_fn(seed_path, seed_len, gen_len, temp):
|
| 66 |
sample_rows = slice(0, seed_len)
|
|
@@ -113,10 +115,11 @@ with gr.Blocks(css=css, title="BachNet") as demo:
|
|
| 113 |
gr.Markdown("## Customize Your Chorale")
|
| 114 |
with gr.Row():
|
| 115 |
sample_seed_btn = gr.Button("Pick Random Seed", variant="primary")
|
|
|
|
| 116 |
seed_path_box = gr.Textbox(label="Selected Seed Path", interactive=False)
|
| 117 |
|
| 118 |
seed_len_slider = gr.Slider(40, 80, 50, step=1, label="Seed Length")
|
| 119 |
-
gen_len_slider = gr.Slider(20, 100, 30, step=1, label="Generated Length")
|
| 120 |
temp_slider = gr.Slider(0.5, 1.8, 0.9, step=0.1, label="Temperature")
|
| 121 |
|
| 122 |
generate_btn = gr.Button("Generate", variant="primary")
|
|
@@ -130,11 +133,10 @@ with gr.Blocks(css=css, title="BachNet") as demo:
|
|
| 130 |
|
| 131 |
|
| 132 |
### EVENTS ###
|
| 133 |
-
demo.load(pick_random_seed, outputs=seed_path_box)
|
| 134 |
-
sample_seed_btn.click(pick_random_seed, outputs=seed_path_box)
|
| 135 |
-
generate_btn.click(generate_fn, inputs=[
|
| 136 |
-
|
| 137 |
-
|
| 138 |
english_btn.click(set_english, outputs=[title_md, summary_md, help_md])
|
| 139 |
persian_btn.click(set_persian, outputs=[title_md, summary_md, help_md])
|
| 140 |
|
|
|
|
| 1 |
### IMPORTS ###
|
| 2 |
import os
|
| 3 |
os.environ["KERAS_BACKEND"] ="tensorflow"
|
| 4 |
+
os.environ["TF_ENABLE_XLA"] = "0"
|
| 5 |
import random
|
| 6 |
import keras
|
| 7 |
import gradio as gr
|
| 8 |
+
import time
|
| 9 |
from src.inference import generate_chorale, draw_random_sample
|
| 10 |
from src.dataset import NoteEncoder
|
| 11 |
from src.metrics import Preplexity
|
|
|
|
| 61 |
|
| 62 |
### GENERATION FUNCTIONS ###
|
| 63 |
def pick_random_seed():
|
| 64 |
+
path = draw_random_sample(VAL_PATH, seed=random.randint(0, 9999))
|
| 65 |
+
return path, os.path.basename(path)
|
| 66 |
|
| 67 |
def generate_fn(seed_path, seed_len, gen_len, temp):
|
| 68 |
sample_rows = slice(0, seed_len)
|
|
|
|
| 115 |
gr.Markdown("## Customize Your Chorale")
|
| 116 |
with gr.Row():
|
| 117 |
sample_seed_btn = gr.Button("Pick Random Seed", variant="primary")
|
| 118 |
+
hidden_full_path = gr.State()
|
| 119 |
seed_path_box = gr.Textbox(label="Selected Seed Path", interactive=False)
|
| 120 |
|
| 121 |
seed_len_slider = gr.Slider(40, 80, 50, step=1, label="Seed Length")
|
| 122 |
+
gen_len_slider = gr.Slider(20, 100, 30, step=1, label="Generated Length (Chords)")
|
| 123 |
temp_slider = gr.Slider(0.5, 1.8, 0.9, step=0.1, label="Temperature")
|
| 124 |
|
| 125 |
generate_btn = gr.Button("Generate", variant="primary")
|
|
|
|
| 133 |
|
| 134 |
|
| 135 |
### EVENTS ###
|
| 136 |
+
demo.load(pick_random_seed, outputs=[hidden_full_path, seed_path_box])
|
| 137 |
+
sample_seed_btn.click(pick_random_seed, outputs=[hidden_full_path, seed_path_box])
|
| 138 |
+
generate_btn.click(generate_fn, inputs=[hidden_full_path, seed_len_slider, gen_len_slider, temp_slider],
|
| 139 |
+
outputs=audio_player)
|
|
|
|
| 140 |
english_btn.click(set_english, outputs=[title_md, summary_md, help_md])
|
| 141 |
persian_btn.click(set_persian, outputs=[title_md, summary_md, help_md])
|
| 142 |
|