File size: 5,579 Bytes
5201951
 
 
1e58d28
5201951
 
 
1e58d28
5201951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49ddb63
5201951
 
 
 
 
 
 
 
49ddb63
5201951
 
 
49ddb63
 
5201951
 
49ddb63
 
 
5201951
49ddb63
5201951
 
 
4883178
49ddb63
5201951
 
 
 
 
 
 
49ddb63
5201951
 
 
 
1e58d28
 
5201951
 
 
49ddb63
 
 
5201951
 
 
 
 
 
 
 
 
 
49ddb63
 
 
5201951
49ddb63
 
 
 
5201951
 
 
 
 
 
 
 
 
 
 
 
 
 
587c50a
5201951
 
 
 
 
 
 
 
 
 
 
 
 
1e58d28
5201951
 
245cd4f
1e58d28
9f3d61c
5201951
 
 
 
 
245cd4f
5201951
 
 
 
 
 
1e58d28
 
 
 
5201951
 
 
 
 
49ddb63
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
### IMPORTS ###
import os
os.environ["KERAS_BACKEND"] ="tensorflow"
os.environ["TF_ENABLE_XLA"] = "0"
import random
import keras
import gradio as gr
import time 
from src.inference import generate_chorale, draw_random_sample
from src.dataset import NoteEncoder
from src.metrics import Preplexity
from src.config import URL
from src.utils import get_dataset_path, midi_to_wave, load_css, load_markdown


### SETUP ###
ROOT_DIR = os.getcwd()
TRAIN_PATH, VAL_PATH, ARTIFACTS_PATH, MODEL_PATH = get_dataset_path(ROOT_DIR, URL)
AUDIO_SAMPLES_PATH = os.path.join(ROOT_DIR, "samples")
os.makedirs(AUDIO_SAMPLES_PATH, exist_ok=True)
midi_path = os.path.join(AUDIO_SAMPLES_PATH, "sample.mid")
wav_path = os.path.join(AUDIO_SAMPLES_PATH, "sample.wav")


### DOWNLOAD SF2 MUSIC FONT ###
t0 = time.time()
sf2_download_path = keras.utils.get_file(
        "FluidR3_GM.zip",
        "https://keymusician01.s3.amazonaws.com/FluidR3_GM.zip",
        extract= True,
        cache_dir= ARTIFACTS_PATH,
        cache_subdir= ""
        )
SF2_PATH = os.path.join(sf2_download_path, "FluidR3_GM.sf2")
print(f"SF2 font download/extract took {time.time()-t0:.2f}s")


### LOAD MODEL & ENCODERS ###
print("loading model...")
t0 = time.time()
model = keras.models.load_model(os.path.join(MODEL_PATH, "bach_model.keras"),
                                custom_objects={"Preplexity": Preplexity})
print(f"model loaded in {time.time()-t0:.2f}s")

t1 = time.time()
note2id, id2note, vocab = NoteEncoder(vocab_path=ARTIFACTS_PATH, samples_path=None)
print(f"NoteEncoder init took {time.time()-t1:.2f}s")


### GRADIO ASSETS ###
theme = gr.themes.Soft( font=[gr.themes.GoogleFont("Vazirmatn"),"Segoe UI", "system-ui"])
t2 = time.time()
css = load_css()
english_summary = load_markdown("english_summary")
persian_summary = load_markdown("persian_summary")
english_help = load_markdown("english_help")
persian_help = load_markdown("persian_help")
english_title = "# BachNet: AI-Generated Bach Music"
persian_title = "# باخ‌نت: خلق موسیقی مشابه باخ با هوش مصنوعی"
print(f"Loaded CSS/markdown assets in {time.time()-t2:.2f}s")


### GENERATION FUNCTIONS ###
def pick_random_seed():
    path = draw_random_sample(VAL_PATH, seed=random.randint(0, 9999))
    return path, os.path.basename(path)

def generate_fn(seed_path, seed_len, gen_len, temp):
    sample_rows = slice(0, seed_len)

    print("=== Generation started ===")
    t0 = time.time()
    generate_chorale(
        model=model,
        sample_seed_path=seed_path,
        note2id=note2id,
        id2note=id2note,
        file_name=midi_path,
        max_len=gen_len,
        temperature=temp,
        sample_seed_rows=sample_rows
    )
    t1 = time.time()
    print(f"generate_chorale took {t1 - t0:.2f}s")

    midi_to_wave(midi_file_path=midi_path, SF2_PATH=SF2_PATH, wave_path=wav_path)
    t2 = time.time()
    print(f"midi_to_wave took {t2 - t1:.2f}s")
    print(f"Total generate_fn time {t2 - t0:.2f}s")

    return wav_path

def set_english():
    return (gr.update(value=english_title, elem_classes=[]),
            gr.update(value=english_summary, elem_classes=[]),
            gr.update(value=english_help, elem_classes=[]))
    
def set_persian():
    return (gr.update(value=persian_title, elem_classes=['persian']),
            gr.update(value=persian_summary, elem_classes=['persian']),
            gr.update(value=persian_help, elem_classes=['persian']))


### GRADIO APP ###
with gr.Blocks(css=css, title="BachNet", theme=theme) as demo:
    title_md = gr.Markdown(english_title, elem_id="title")
    
    with gr.Row():
        english_btn = gr.Button("English")
        persian_btn = gr.Button("Persian (فارسی)")
    
    summary_md = gr.Markdown(english_summary, elem_id="summary", max_height=None)
    
    with gr.Row(variant="panel"):
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("## Customize Your Chorale")
            with gr.Row():
                sample_seed_btn = gr.Button("Pick Random Seed", variant="primary")
                hidden_full_path = gr.State()
                seed_path_box = gr.Textbox(label="Selected Seed Path", interactive=False)
            
            seed_len_slider = gr.Slider(40, 80, 50, step=1, label="Seed Length")
            gen_len_slider = gr.Slider(20, 100, 30, step=1, label="Generated Length (Chords)")
            temp_slider = gr.Slider(0.5, 1.8, 0.9, step=0.1, label="Temperature")
            
            generate_btn = gr.Button("Generate", variant="primary")
        
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("## Generated Music: Listen & Download")
            gr.Markdown("⚠️ *Note: Running on CPU — generation may take ~15 seconds on default settings.*",elem_id="cpu_warning")
            audio_player = gr.Audio(label="Generated Chorale", type="filepath",
                                    interactive=False, show_download_button=True, streaming=True, autoplay=True)
            help_md = gr.Markdown(english_help, elem_id="help_text")


### EVENTS ###
    demo.load(pick_random_seed, outputs=[hidden_full_path, seed_path_box])
    sample_seed_btn.click(pick_random_seed, outputs=[hidden_full_path, seed_path_box])
    generate_btn.click(generate_fn, inputs=[hidden_full_path, seed_len_slider, gen_len_slider, temp_slider],
    outputs=audio_player)
    english_btn.click(set_english, outputs=[title_md, summary_md, help_md])
    persian_btn.click(set_persian, outputs=[title_md, summary_md, help_md])

### LAUNCH APP ###
if __name__ == "__main__":
    demo.launch()