BachNet / app.py
hoom4n's picture
Update app.py
4883178 verified
### IMPORTS ###
import os
os.environ["KERAS_BACKEND"] ="tensorflow"
os.environ["TF_ENABLE_XLA"] = "0"
import random
import keras
import gradio as gr
import time
from src.inference import generate_chorale, draw_random_sample
from src.dataset import NoteEncoder
from src.metrics import Preplexity
from src.config import URL
from src.utils import get_dataset_path, midi_to_wave, load_css, load_markdown
### SETUP ###
ROOT_DIR = os.getcwd()
TRAIN_PATH, VAL_PATH, ARTIFACTS_PATH, MODEL_PATH = get_dataset_path(ROOT_DIR, URL)
AUDIO_SAMPLES_PATH = os.path.join(ROOT_DIR, "samples")
os.makedirs(AUDIO_SAMPLES_PATH, exist_ok=True)
midi_path = os.path.join(AUDIO_SAMPLES_PATH, "sample.mid")
wav_path = os.path.join(AUDIO_SAMPLES_PATH, "sample.wav")
### DOWNLOAD SF2 MUSIC FONT ###
t0 = time.time()
sf2_download_path = keras.utils.get_file(
"FluidR3_GM.zip",
"https://keymusician01.s3.amazonaws.com/FluidR3_GM.zip",
extract= True,
cache_dir= ARTIFACTS_PATH,
cache_subdir= ""
)
SF2_PATH = os.path.join(sf2_download_path, "FluidR3_GM.sf2")
print(f"SF2 font download/extract took {time.time()-t0:.2f}s")
### LOAD MODEL & ENCODERS ###
print("loading model...")
t0 = time.time()
model = keras.models.load_model(os.path.join(MODEL_PATH, "bach_model.keras"),
custom_objects={"Preplexity": Preplexity})
print(f"model loaded in {time.time()-t0:.2f}s")
t1 = time.time()
note2id, id2note, vocab = NoteEncoder(vocab_path=ARTIFACTS_PATH, samples_path=None)
print(f"NoteEncoder init took {time.time()-t1:.2f}s")
### GRADIO ASSETS ###
theme = gr.themes.Soft( font=[gr.themes.GoogleFont("Vazirmatn"),"Segoe UI", "system-ui"])
t2 = time.time()
css = load_css()
english_summary = load_markdown("english_summary")
persian_summary = load_markdown("persian_summary")
english_help = load_markdown("english_help")
persian_help = load_markdown("persian_help")
english_title = "# BachNet: AI-Generated Bach Music"
persian_title = "# باخ‌نت: خلق موسیقی مشابه باخ با هوش مصنوعی"
print(f"Loaded CSS/markdown assets in {time.time()-t2:.2f}s")
### GENERATION FUNCTIONS ###
def pick_random_seed():
path = draw_random_sample(VAL_PATH, seed=random.randint(0, 9999))
return path, os.path.basename(path)
def generate_fn(seed_path, seed_len, gen_len, temp):
sample_rows = slice(0, seed_len)
print("=== Generation started ===")
t0 = time.time()
generate_chorale(
model=model,
sample_seed_path=seed_path,
note2id=note2id,
id2note=id2note,
file_name=midi_path,
max_len=gen_len,
temperature=temp,
sample_seed_rows=sample_rows
)
t1 = time.time()
print(f"generate_chorale took {t1 - t0:.2f}s")
midi_to_wave(midi_file_path=midi_path, SF2_PATH=SF2_PATH, wave_path=wav_path)
t2 = time.time()
print(f"midi_to_wave took {t2 - t1:.2f}s")
print(f"Total generate_fn time {t2 - t0:.2f}s")
return wav_path
def set_english():
return (gr.update(value=english_title, elem_classes=[]),
gr.update(value=english_summary, elem_classes=[]),
gr.update(value=english_help, elem_classes=[]))
def set_persian():
return (gr.update(value=persian_title, elem_classes=['persian']),
gr.update(value=persian_summary, elem_classes=['persian']),
gr.update(value=persian_help, elem_classes=['persian']))
### GRADIO APP ###
with gr.Blocks(css=css, title="BachNet", theme=theme) as demo:
title_md = gr.Markdown(english_title, elem_id="title")
with gr.Row():
english_btn = gr.Button("English")
persian_btn = gr.Button("Persian (فارسی)")
summary_md = gr.Markdown(english_summary, elem_id="summary", max_height=None)
with gr.Row(variant="panel"):
with gr.Column(scale=1, variant="panel"):
gr.Markdown("## Customize Your Chorale")
with gr.Row():
sample_seed_btn = gr.Button("Pick Random Seed", variant="primary")
hidden_full_path = gr.State()
seed_path_box = gr.Textbox(label="Selected Seed Path", interactive=False)
seed_len_slider = gr.Slider(40, 80, 50, step=1, label="Seed Length")
gen_len_slider = gr.Slider(20, 100, 30, step=1, label="Generated Length (Chords)")
temp_slider = gr.Slider(0.5, 1.8, 0.9, step=0.1, label="Temperature")
generate_btn = gr.Button("Generate", variant="primary")
with gr.Column(scale=1, variant="panel"):
gr.Markdown("## Generated Music: Listen & Download")
gr.Markdown("⚠️ *Note: Running on CPU — generation may take ~15 seconds on default settings.*",elem_id="cpu_warning")
audio_player = gr.Audio(label="Generated Chorale", type="filepath",
interactive=False, show_download_button=True, streaming=True, autoplay=True)
help_md = gr.Markdown(english_help, elem_id="help_text")
### EVENTS ###
demo.load(pick_random_seed, outputs=[hidden_full_path, seed_path_box])
sample_seed_btn.click(pick_random_seed, outputs=[hidden_full_path, seed_path_box])
generate_btn.click(generate_fn, inputs=[hidden_full_path, seed_len_slider, gen_len_slider, temp_slider],
outputs=audio_player)
english_btn.click(set_english, outputs=[title_md, summary_md, help_md])
persian_btn.click(set_persian, outputs=[title_md, summary_md, help_md])
### LAUNCH APP ###
if __name__ == "__main__":
demo.launch()