hoom4n committed on
Commit
5201951
·
verified ·
1 Parent(s): cfa41f1

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +122 -0
  2. packages.txt +1 -0
  3. requirements.txt +4 -0
  4. train.py +33 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### IMPORTS ###
2
+ import os
3
+ os.environ["KERAS_BACKEND"] ="tensorflow"
4
+ import random
5
+ import keras
6
+ import gradio as gr
7
+ from src.inference import generate_chorale, draw_random_sample
8
+ from src.dataset import NoteEncoder
9
+ from src.metrics import Preplexity
10
+ from src.config import URL
11
+ from src.utils import get_dataset_path, midi_to_wave, load_css, load_markdown
12
+
13
+
14
### SETUP ###
# All working paths are anchored at the current working directory.
# NOTE(review): get_dataset_path presumably fetches the dataset from URL on
# first use (train.py labels the same call "DOWNLOAD DATASET") -- confirm.
ROOT_DIR = os.getcwd()
TRAIN_PATH, VAL_PATH, ARTIFACTS_PATH, MODEL_PATH = get_dataset_path(ROOT_DIR, URL)

# Folder that holds the MIDI/WAV pair written by generate_fn; the same two
# file names are reused on every generation.
AUDIO_SAMPLES_PATH = os.path.join(ROOT_DIR, "samples")
os.makedirs(AUDIO_SAMPLES_PATH, exist_ok=True)
midi_path = os.path.join(AUDIO_SAMPLES_PATH, "sample.mid")
wav_path = os.path.join(AUDIO_SAMPLES_PATH, "sample.wav")


### DOWNLOAD SF2 MUSIC FONT ###
# The archive is cached directly under ARTIFACTS_PATH (empty cache_subdir).
# The value returned by get_file is used below as the directory that contains
# the extracted soundfont.
sf2_download_path = keras.utils.get_file(
    "FluidR3_GM.zip",
    "https://keymusician01.s3.amazonaws.com/FluidR3_GM.zip",
    extract=True,
    cache_dir=ARTIFACTS_PATH,
    cache_subdir="",
)
SF2_PATH = os.path.join(sf2_download_path, "FluidR3_GM.sf2")


### LOAD MODEL & ENCODERS ###
# The custom metric class has to be registered under its saved name so the
# serialized model can be deserialized.
_model_file = os.path.join(MODEL_PATH, "bach_model.keras")
model = keras.models.load_model(
    _model_file,
    custom_objects={"Preplexity": Preplexity},
)
# No samples_path is given here; only the vocabulary location is supplied.
note2id, id2note, vocab = NoteEncoder(vocab_path=ARTIFACTS_PATH, samples_path=None)


### GRADIO ASSETS ###
css = load_css()

# Localized long-form texts, loaded by asset name.
english_summary = load_markdown("english_summary")
persian_summary = load_markdown("persian_summary")
english_help = load_markdown("english_help")
persian_help = load_markdown("persian_help")

# Page titles (Markdown level-1 headings).
english_title = "# BachNet: AI-Generated Bach Music"
persian_title = "# باخ‌نت: خلق موسیقی مشابه باخ با هوش مصنوعی"
48
+
49
+
50
+ ### GENERATION FUNCTIONS ###
51
def pick_random_seed():
    """Draw a random seed chorale from the validation set.

    A fresh integer in [0, 9999] is handed to ``draw_random_sample`` as its
    ``seed`` so each click can select a different sample.

    Returns:
        Whatever ``draw_random_sample`` returns for ``VAL_PATH``; the UI shows
        it in the "Selected Seed Path" textbox.
    """
    rng_seed = random.randint(0, 9999)
    return draw_random_sample(VAL_PATH, seed=rng_seed)
53
+
54
def generate_fn(seed_path, seed_len, gen_len, temp):
    """Generate a chorale from a seed and render it to a WAV file.

    Args:
        seed_path: Value of the "Selected Seed Path" textbox, forwarded to
            ``generate_chorale`` as ``sample_seed_path``.
        seed_len: Number of leading seed rows to use (rows ``0:seed_len``).
        gen_len: Forwarded to ``generate_chorale`` as ``max_len``.
        temp: Forwarded to ``generate_chorale`` as ``temperature``.

    Returns:
        ``wav_path``, the location of the rendered audio file.

    Raises:
        gr.Error: If no seed has been selected yet.
    """
    # The seed textbox stays empty until "Pick Random Seed" is clicked. Fail
    # with a readable message instead of an opaque error from the pipeline.
    if not seed_path or not str(seed_path).strip():
        raise gr.Error("Please pick a random seed before generating.")

    # Coerce defensively: slice bounds must be integers, and both sliders use
    # step=1, so no precision is lost.
    seed_len = int(seed_len)
    gen_len = int(gen_len)
    sample_rows = slice(0, seed_len)

    # Writes the generated chorale to midi_path.
    generate_chorale(
        model=model,
        sample_seed_path=seed_path,
        note2id=note2id,
        id2note=id2note,
        file_name=midi_path,
        max_len=gen_len,
        temperature=temp,
        sample_seed_rows=sample_rows,
    )

    # Render the MIDI file to audio with the downloaded soundfont.
    midi_to_wave(midi_file_path=midi_path, SF2_PATH=SF2_PATH, wave_path=wav_path)
    return wav_path
70
+
71
def set_english():
    """Build the updates that switch the page texts to English.

    Returns:
        A 3-tuple of ``gr.update`` objects for the title, summary and help
        Markdown components, in that order, each with its CSS classes cleared.
    """
    english_texts = (english_title, english_summary, english_help)
    return tuple(gr.update(value=text, elem_classes=[]) for text in english_texts)
75
+
76
def set_persian():
    """Build the updates that switch the page texts to Persian.

    Returns:
        A 3-tuple of ``gr.update`` objects for the title, summary and help
        Markdown components, in that order, each tagged with the ``persian``
        CSS class.
    """
    persian_texts = (persian_title, persian_summary, persian_help)
    return tuple(
        gr.update(value=text, elem_classes=['persian']) for text in persian_texts
    )
80
+
81
+
82
### GRADIO APP ###
# NOTE(review): the indentation of this section was not available when it was
# reformatted; the nesting of the layout contexts below is reconstructed from
# statement order -- confirm against the repository before relying on it.
with gr.Blocks(css=css, title="BachNet") as demo:
    title_md = gr.Markdown(value=english_title, elem_id="title")

    # Language switch.
    with gr.Row():
        english_btn = gr.Button(value="English")
        persian_btn = gr.Button(value="Persian (فارسی)")

    summary_md = gr.Markdown(value=english_summary, elem_id="summary", max_height=None)

    with gr.Row(variant="panel"):
        # Left panel: generation controls.
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown(value="## Customize Your Chorale")
            with gr.Row():
                sample_seed_btn = gr.Button(value="Pick Random Seed", variant="primary")
                seed_path_box = gr.Textbox(label="Selected Seed Path", interactive=False)

            seed_len_slider = gr.Slider(
                minimum=50, maximum=150, value=80, step=1, label="Seed Length"
            )
            gen_len_slider = gr.Slider(
                minimum=20, maximum=200, value=50, step=1, label="Generated Length"
            )
            temp_slider = gr.Slider(
                minimum=0.5, maximum=1.8, value=1.0, step=0.1, label="Temperature"
            )

            generate_btn = gr.Button(value="Generate", variant="primary")

        # Right panel: playback and help text.
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown(value="## Generated Music: Listen & Download")
            audio_player = gr.Audio(
                label="Generated Chorale",
                type="filepath",
                interactive=False,
                show_download_button=True,
                streaming=True,
                autoplay=True,
            )
            help_md = gr.Markdown(value=english_help, elem_id="help_text")

    ### EVENTS ###
    sample_seed_btn.click(fn=pick_random_seed, outputs=seed_path_box)
    generate_btn.click(
        fn=generate_fn,
        inputs=[seed_path_box, seed_len_slider, gen_len_slider, temp_slider],
        outputs=audio_player,
    )

    # Both language buttons rewrite the same three Markdown components.
    english_btn.click(fn=set_english, outputs=[title_md, summary_md, help_md])
    persian_btn.click(fn=set_persian, outputs=[title_md, summary_md, help_md])

### LAUNCH APP ###
if __name__ == "__main__":
    demo.launch()
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ fluidsynth
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ tensorflow==2.19.0
2
+ numpy==2.1.3
3
+ gradio==5.49.0
4
+ music21==9.7.1
train.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.config import *
2
+ from src.dataset import NoteEncoder, seq2seq_dataset
3
+ from src.model import get_model
4
+ from src.trainer import train_model
5
+ from src.utils import get_dataset_path
6
+ import keras
7
+ import os
8
+
9
### DOWNLOAD DATASET ###
# Paths are resolved relative to the directory the script is launched from.
ROOT_DIR = os.getcwd()
TRAIN_PATH, VAL_PATH, ARTIFACTS_PATH, MODEL_PATH = get_dataset_path(ROOT_DIR, URL)

### REPRODUCIBILITY ###
keras.utils.set_random_seed(SEED)

### INITIALIZE MODEL & DATASET ###
# The encoder is built from the training samples; the vocabulary size drives
# the embedding input dimension of the model.
note2id, id2note, vocab = NoteEncoder(samples_path=TRAIN_PATH, vocab_path=ARTIFACTS_PATH)
vocab_size = len(vocab)

# Training windows are shuffled with a fixed seed.
train = seq2seq_dataset(
    TRAIN_PATH + "/*.csv",
    note2id,
    seq_len=SEQ_LEN,
    window_shift=WINDOW_SHIFT,
    batch_size=BATCH_SIZE,
    shuffle_buffer=2500,
    seed=SEED,
)

# Validation windows are not shuffled (no shuffle buffer, no seed passed).
val = seq2seq_dataset(
    VAL_PATH + "/*.csv",
    note2id,
    seq_len=SEQ_LEN,
    window_shift=WINDOW_SHIFT,
    batch_size=BATCH_SIZE,
    shuffle_buffer=None,
)

bach_model = get_model(
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    emb_in=vocab_size,
    emb_out=EMBEDDING_DIM,
    lstm_layers=LSTM_LAYERS,
    lstm_units=LSTM_UNITS,
    lstm_dropout=LSTM_DROPOUT,
    dense_units=DENSE_UNITS,
    dropout=DROPOUT,
)

### TRAINER ###
train_model(bach_model, train, val, N_EPOCHS, ARTIFACTS_PATH, MODEL_PATH)