### IMPORTS ###
import os
os.environ["KERAS_BACKEND"] ="tensorflow"
os.environ["TF_ENABLE_XLA"] = "0"
import random
import keras
import gradio as gr
import time
from src.inference import generate_chorale, draw_random_sample
from src.dataset import NoteEncoder
from src.metrics import Preplexity
from src.config import URL
from src.utils import get_dataset_path, midi_to_wave, load_css, load_markdown
### SETUP ###
# All project paths are resolved relative to the current working directory.
ROOT_DIR = os.getcwd()

# get_dataset_path hands back the four project locations in this fixed order.
(
    TRAIN_PATH,
    VAL_PATH,
    ARTIFACTS_PATH,
    MODEL_PATH,
) = get_dataset_path(ROOT_DIR, URL)

# Directory that receives the generated MIDI file and its WAV rendering.
AUDIO_SAMPLES_PATH = os.path.join(ROOT_DIR, "samples")
os.makedirs(AUDIO_SAMPLES_PATH, exist_ok=True)

# Both output files share one fixed name, so each generation overwrites the
# previous one.
midi_path, wav_path = (
    os.path.join(AUDIO_SAMPLES_PATH, file_name)
    for file_name in ("sample.mid", "sample.wav")
)
### DOWNLOAD SF2 MUSIC FONT ###
# Fetch (or reuse the cached copy of) the FluidR3 General MIDI sound font and
# unpack it directly under ARTIFACTS_PATH (empty cache_subdir).
t0 = time.time()
sf2_download_path = keras.utils.get_file(
    fname="FluidR3_GM.zip",
    origin="https://keymusician01.s3.amazonaws.com/FluidR3_GM.zip",
    extract=True,
    cache_dir=ARTIFACTS_PATH,
    cache_subdir="",
)
# NOTE(review): this join assumes get_file returns the extraction directory
# when extract=True — confirm against the installed Keras version.
SF2_PATH = os.path.join(sf2_download_path, "FluidR3_GM.sf2")
print(f"SF2 font download/extract took {time.time()-t0:.2f}s")
### LOAD MODEL & ENCODERS ###
print("loading model...")
t0 = time.time()
# custom_objects lets Keras resolve the project's Preplexity class by name
# while deserializing the saved model.
model = keras.models.load_model(
    filepath=os.path.join(MODEL_PATH, "bach_model.keras"),
    custom_objects={"Preplexity": Preplexity},
)
print(f"model loaded in {time.time()-t0:.2f}s")

# Note <-> id lookup tables plus the vocabulary, built from the artifacts
# directory (no samples path is supplied here).
t1 = time.time()
note2id, id2note, vocab = NoteEncoder(
    vocab_path=ARTIFACTS_PATH,
    samples_path=None,
)
print(f"NoteEncoder init took {time.time()-t1:.2f}s")
### GRADIO ASSETS ###
# Soft theme with the Vazirmatn Google font first, then two fallbacks.
theme = gr.themes.Soft(
    font=[
        gr.themes.GoogleFont("Vazirmatn"),
        "Segoe UI",
        "system-ui",
    ]
)

t2 = time.time()
css = load_css()

# Load the four markdown documents in the same order as their target names.
english_summary, persian_summary, english_help, persian_help = (
    load_markdown(doc_name)
    for doc_name in (
        "english_summary",
        "persian_summary",
        "english_help",
        "persian_help",
    )
)

# Page titles are short enough to keep inline instead of in asset files.
english_title = "# BachNet: AI-Generated Bach Music"
persian_title = "# باخنت: خلق موسیقی مشابه باخ با هوش مصنوعی"
print(f"Loaded CSS/markdown assets in {time.time()-t2:.2f}s")
### GENERATION FUNCTIONS ###
def pick_random_seed():
    """Draw a random seed sample from the validation set.

    Returns:
        A ``(full_path, file_name)`` tuple: the path returned by
        ``draw_random_sample`` and its base name.
    """
    rng_seed = random.randint(0, 9999)
    sample_path = draw_random_sample(VAL_PATH, seed=rng_seed)
    file_name = os.path.basename(sample_path)
    return sample_path, file_name
def generate_fn(seed_path, seed_len, gen_len, temp):
    """Generate a chorale from a seed sample and render it to a WAV file.

    The generated notes are written to the module-level ``midi_path`` and
    then converted to ``wav_path`` with the SF2 sound font. Timing for each
    stage is printed to stdout.

    Args:
        seed_path: Path of the seed sample forwarded to ``generate_chorale``.
        seed_len: Number of leading seed rows to use; turned into
            ``slice(0, seed_len)``.
        gen_len: Value forwarded to ``generate_chorale`` as ``max_len``.
        temp: Sampling temperature forwarded to ``generate_chorale``.

    Returns:
        ``wav_path``, the path of the rendered audio file.

    Raises:
        gr.Error: If no seed sample has been selected yet.
    """
    # Fail early with a message the UI can show, instead of an opaque error
    # from deep inside generate_chorale when the seed state is still empty.
    if not seed_path:
        raise gr.Error("No seed selected. Click 'Pick Random Seed' first.")

    # Slider values may arrive as floats; slice bounds must be integers.
    seed_len = int(seed_len)
    gen_len = int(gen_len)

    sample_rows = slice(0, seed_len)
    print("=== Generation started ===")
    t0 = time.time()
    generate_chorale(
        model=model,
        sample_seed_path=seed_path,
        note2id=note2id,
        id2note=id2note,
        file_name=midi_path,
        max_len=gen_len,
        temperature=temp,
        sample_seed_rows=sample_rows,
    )
    t1 = time.time()
    print(f"generate_chorale took {t1 - t0:.2f}s")

    # Render the MIDI that was just written into playable audio.
    midi_to_wave(midi_file_path=midi_path, SF2_PATH=SF2_PATH, wave_path=wav_path)
    t2 = time.time()
    print(f"midi_to_wave took {t2 - t1:.2f}s")
    print(f"Total generate_fn time {t2 - t0:.2f}s")
    return wav_path
def set_english():
    """Build the updates that switch title, summary and help to English.

    Returns:
        A 3-tuple of ``gr.update`` objects (title, summary, help), each
        carrying the English text and an empty ``elem_classes`` list.
    """
    english_texts = (english_title, english_summary, english_help)
    return tuple(
        gr.update(value=text, elem_classes=[]) for text in english_texts
    )
def set_persian():
    """Build the updates that switch title, summary and help to Persian.

    Returns:
        A 3-tuple of ``gr.update`` objects (title, summary, help), each
        carrying the Persian text and the ``persian`` element class.
    """
    persian_texts = (persian_title, persian_summary, persian_help)
    return tuple(
        gr.update(value=text, elem_classes=["persian"]) for text in persian_texts
    )
### GRADIO APP ###
# Declarative layout: components appear on the page in the order they are
# created, so the statement order below is significant.
# NOTE(review): the nesting was reconstructed from a copy whose indentation
# had been stripped — confirm which components belong to each Row/Column.
with gr.Blocks(css=css, title="BachNet", theme=theme) as demo:
    title_md = gr.Markdown(english_title, elem_id="title")
    # Language switch buttons.
    with gr.Row():
        english_btn = gr.Button("English")
        persian_btn = gr.Button("Persian (فارسی)")
    summary_md = gr.Markdown(english_summary, elem_id="summary", max_height=None)
    with gr.Row(variant="panel"):
        # Left panel: generation controls.
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("## Customize Your Chorale")
            with gr.Row():
                sample_seed_btn = gr.Button("Pick Random Seed", variant="primary")
                # State keeps the full seed path; the textbox only displays
                # the file name returned by pick_random_seed.
                hidden_full_path = gr.State()
                seed_path_box = gr.Textbox(label="Selected Seed Path", interactive=False)
            # Slider positional arguments are minimum, maximum, initial value.
            seed_len_slider = gr.Slider(40, 80, 50, step=1, label="Seed Length")
            gen_len_slider = gr.Slider(20, 100, 30, step=1, label="Generated Length (Chords)")
            temp_slider = gr.Slider(0.5, 1.8, 0.9, step=0.1, label="Temperature")
            generate_btn = gr.Button("Generate", variant="primary")
        # Right panel: playback of the generated audio.
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("## Generated Music: Listen & Download")
            gr.Markdown("⚠️ *Note: Running on CPU — generation may take ~15 seconds on default settings.*",elem_id="cpu_warning")
            # Receives the WAV file path returned by generate_fn.
            audio_player = gr.Audio(label="Generated Chorale", type="filepath",
                                    interactive=False, show_download_button=True, streaming=True, autoplay=True)
    help_md = gr.Markdown(english_help, elem_id="help_text")
    ### EVENTS ###
    # Pick a seed on page load so a seed is already selected before the
    # first click on "Generate".
    demo.load(pick_random_seed, outputs=[hidden_full_path, seed_path_box])
    sample_seed_btn.click(pick_random_seed, outputs=[hidden_full_path, seed_path_box])
    # Input order must match generate_fn(seed_path, seed_len, gen_len, temp).
    generate_btn.click(generate_fn, inputs=[hidden_full_path, seed_len_slider, gen_len_slider, temp_slider],
                       outputs=audio_player)
    # Output order must match the tuples returned by set_english/set_persian
    # (title, summary, help).
    english_btn.click(set_english, outputs=[title_md, summary_md, help_md])
    persian_btn.click(set_persian, outputs=[title_md, summary_md, help_md])
### LAUNCH APP ###
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|