Upload 17 files
Browse files- .gitattributes +1 -0
- artifacts/train_logs.json +51 -0
- artifacts/vocab.npy +3 -0
- assets/css/theme.css +61 -0
- assets/markdown/english_help.md +4 -0
- assets/markdown/english_summary.md +5 -0
- assets/markdown/persian_help.md +4 -0
- assets/markdown/persian_summary.md +6 -0
- data/jsb_chorales.zip +3 -0
- model/bach_model.keras +3 -0
- samples/sample.mid +0 -0
- src/config.py +14 -0
- src/dataset.py +56 -0
- src/inference.py +42 -0
- src/metrics.py +19 -0
- src/model.py +28 -0
- src/trainer.py +20 -0
- src/utils.py +48 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
model/bach_model.keras filter=lfs diff=lfs merge=lfs -text
|
artifacts/train_logs.json
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"Preplexity": [
|
| 3 |
+
2.019901752471924,
|
| 4 |
+
1.3914695978164673,
|
| 5 |
+
1.28533136844635,
|
| 6 |
+
1.2096575498580933,
|
| 7 |
+
1.1621309518814087
|
| 8 |
+
],
|
| 9 |
+
"accuracy": [
|
| 10 |
+
0.7929478287696838,
|
| 11 |
+
0.9003167748451233,
|
| 12 |
+
0.9229070544242859,
|
| 13 |
+
0.9404971599578857,
|
| 14 |
+
0.9523513913154602
|
| 15 |
+
],
|
| 16 |
+
"loss": [
|
| 17 |
+
0.7067986726760864,
|
| 18 |
+
0.3346928358078003,
|
| 19 |
+
0.25587165355682373,
|
| 20 |
+
0.19564126431941986,
|
| 21 |
+
0.1558060199022293
|
| 22 |
+
],
|
| 23 |
+
"val_Preplexity": [
|
| 24 |
+
2.1313233375549316,
|
| 25 |
+
2.116178512573242,
|
| 26 |
+
2.2391598224639893,
|
| 27 |
+
2.3532228469848633,
|
| 28 |
+
2.4712650775909424
|
| 29 |
+
],
|
| 30 |
+
"val_accuracy": [
|
| 31 |
+
0.8125607967376709,
|
| 32 |
+
0.8240785598754883,
|
| 33 |
+
0.8254508376121521,
|
| 34 |
+
0.8284142017364502,
|
| 35 |
+
0.8301998376846313
|
| 36 |
+
],
|
| 37 |
+
"val_loss": [
|
| 38 |
+
0.7608373165130615,
|
| 39 |
+
0.7542304992675781,
|
| 40 |
+
0.811232328414917,
|
| 41 |
+
0.8612667322158813,
|
| 42 |
+
0.9103369116783142
|
| 43 |
+
],
|
| 44 |
+
"learning_rate": [
|
| 45 |
+
0.0010000000474974513,
|
| 46 |
+
0.0009440609137527645,
|
| 47 |
+
0.0008413951727561653,
|
| 48 |
+
0.0007079457864165306,
|
| 49 |
+
0.000562341301701963
|
| 50 |
+
]
|
| 51 |
+
}
|
artifacts/vocab.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:82d14189d3504eeef49f6a39d240ffeb9793e1a11c0eb73ea0516ff47e774aff
|
| 3 |
+
size 504
|
assets/css/theme.css
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Page background: light blue gradient by default, deep purple/indigo in dark mode. */
.gradio-container {
    background: linear-gradient(to bottom right, #e0f7fa, #b3e5fc) !important;
}
.dark .gradio-container {
    background: linear-gradient(to bottom right, #2a0a3a, #1e1a5e) !important;
}
/* Rounded buttons with a blue-to-pink gradient. */
button {
    border-radius: 20px !important;
    background: linear-gradient(to right, #4a90e2, #e94e77) !important;
    color: white !important;
}
/* App title. */
#title {
    font-size: 2.5em !important;
    color: #1e3a8a;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
    text-align: center;
    margin-top: 36px;
    margin-bottom: 10px;
}
.dark #title {
    color: #e0f7fa !important;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
}
/* Keep the title centered even when the RTL .persian class is applied. */
#title.persian {
    text-align: center !important;
}
/* Summary card. */
#summary {
    color: #334155;
    background: rgba(255,255,255,0.8);
    padding: 15px;
    border-radius: 15px;
    box-shadow: 0 4px 6px rgba(0,0,0,0.1);
    margin-bottom: 2px;
    text-align: justify !important;
}
.dark #summary {
    color: #d1d5db !important;
    background: rgba(30, 30, 46, 0.8) !important;
}
/* Help panel with a left accent border. */
#help_text {
    color: #1f2937;
    background: #f0f9ff;
    padding: 15px;
    border-left: 5px solid #3b82f6;
    border-radius: 12px;
    box-shadow: 0 4px 6px rgba(0,0,0,0.05);
    margin-bottom: 20px;
    text-align: justify !important;
}
.dark #help_text {
    color: #d1d5db !important;
    background: rgba(30, 30, 46, 0.8) !important;
    border-left: 5px solid #60a5fa !important;
}
/* Right-to-left layout for Persian text. */
.persian {
    direction: rtl;
    text-align: right;
}
/* Persian body panels still use justified text, overriding .persian's right-align. */
#summary.persian, #help_text.persian {
    text-align: justify !important;
}
|
assets/markdown/english_help.md
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**How to Generate Music**
|
| 2 |
+
1. **Pick a seed:** Use a few seconds of an existing chorale from the validation set (unseen by the model) as the initial seed.
|
| 3 |
+
2. **Control randomness:** Adjust the **Hotness slider**—higher values produce more diverse outputs, lower values are more conservative.
|
| 4 |
+
3. **Adjust lengths:** Set the **seed length** for initial context and the **generated length** for output duration.
|
assets/markdown/english_summary.md
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**BachNet🎵** is a production‑ready deep learning system for generating music in the style of J.S. Bach. Trained on a corpus of 382 chorales with a multi‑layer, sequence‑to‑sequence LSTM network, it learns both melodic patterns and temporal structures from sequences of 256 notes.
|
| 2 |
+
|
| 3 |
+
From a short seed segment of a chorale, BachNet can autoregressively compose entirely new pieces. Notes are sampled from a categorical distribution, with the degree of variation controlled by a temperature parameter.
|
| 4 |
+
|
| 5 |
+
The project also incorporates a fully in‑graph TensorFlow data‑streaming pipeline, enabling efficient, on‑the‑fly creation and batching of training samples. This design keeps the CPU busy preparing data while the GPU remains fully utilized for model training, maximizing both throughput and performance. Project GitHub: [github.com/hoom4n/BachNet](https://github.com/hoom4n/BachNet)
|
assets/markdown/persian_help.md
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**راهنمای تولید موسیقی**
|
| 2 |
+
1. با Pick Random Seed بصورت تصادفی از میان کرال های داده ولیدیشن (که در آموزش استفاده نشده اند) یکی را بعنوان دانه اولیه انتخاب کنید.
|
| 3 |
+
2. با اسلایدر Generated Length طول دلخواه موسیقی تولیدشده را مشخص کنید. توجه داشته باشید که طولانیتر شدن اثر به زمان تولید بیشتری نیاز دارد.
|
| 4 |
+
3. با اسلایدر Temperature میزان تصادفی بودن نتها را تنظیم کنید: دمای بالاتر انتخاب نتها را تصادفیتر میکند و دمای پایینتر منظمتر شدن انتخاب نتها را به همراه دارد.
|
assets/markdown/persian_summary.md
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**باخنت🎵** سامانهای بر پایهی یادگیری ماشین است که برای آفرینش موسیقی در سبک یوهان سباستیان باخ ساخته شده. این مدل با بهرهگیری از شبکهی LSTM چندلایه و معماری دنبالهبهدنباله، روی مجموعهای شامل ۳۸۲ کرال باخ آموزش دیده و از دنبالههای ۲۵۶ نتی، هم الگوهای ملودیک و هم ساختارهای زمانی موسیقی را میآموزد.
|
| 2 |
+
|
| 3 |
+
باخنت قادر است تنها با دریافت بخشی کوتاه از یک کرال بهعنوان «بذر»، ادامهی قطعه را بهصورت خودبازگشتی بسازد و هر بار نتهای تازهای را به توالی بیفزاید. فرایند انتخاب نتها بر اساس توزیع احتمالاتی انجام میشود و میزان خلاقیت یا پیشبینیپذیری خروجی با پارامتر «دما» قابل تنظیم است.
|
| 4 |
+
|
| 5 |
+
این پروژه همچنین از یک خط لولهی دادهی درونگرافی در TensorFlow بهره میبرد که امکان تولید و دستهبندی نمونههای دنبالهبهدنباله را بهصورت برخط و کارآمد فراهم میکند. در این طراحی، پردازش دادهها بر عهدهی CPU است در حالی که GPU بهطور کامل صرف آموزش میشود؛ نتیجه آن است که هم سرعت و هم کارایی سامانه در بالاترین سطح باقی میماند.
|
| 6 |
+
لینک گیت هاب پروژه: [Github.com/hoom4n/BachNet](https://github.com/hoom4n/BachNet)
|
data/jsb_chorales.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fe69a909ee4d54fd7a3054db335c4899f89ab39552edfd1708a3ea6c062c8cb4
|
| 3 |
+
size 215768
|
model/bach_model.keras
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:837f3a94da07a6dc701066079d8c510e1d0d0757288c56a6e21e14e43c5b40bd
|
| 3 |
+
size 68031959
|
samples/sample.mid
ADDED
|
Binary file (888 Bytes). View file
|
|
|
src/config.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Download location of the zipped JSB chorales CSV dataset.
URL = "https://github.com/Hoom4n/BachNet/raw/refs/heads/main/dataset/jsb_chorales.zip"
SEED = 42             # global RNG seed for reproducibility
SEQ_LEN = 256         # notes per training window
WINDOW_SHIFT = 1      # stride between successive sliding windows
BATCH_SIZE = 256
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
EMBEDDING_DIM = 128   # note-embedding output dimension
LSTM_LAYERS = 3       # number of stacked LSTM + LayerNorm pairs
LSTM_UNITS = 512
LSTM_DROPOUT = 0.3
DENSE_UNITS = 256     # 0 disables the pre-logits Dense head (see model.get_model)
DROPOUT = 0.3         # dropout after the Dense head
N_EPOCHS = 10
|
src/dataset.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
import numpy as np
|
| 3 |
+
import keras
|
| 4 |
+
import glob
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
AUTOTUNE = tf.data.AUTOTUNE
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def NoteEncoder(vocab_path, samples_path=None):
    """Load (or build and cache) the note vocabulary and return lookup layers.

    Looks for ``vocab.npy`` under *vocab_path*; when absent, derives the
    vocabulary from every ``*.csv`` under *samples_path* and saves it for
    later runs. Returns ``(note2id, id2note, vocab)`` where the two
    ``IntegerLookup`` layers map raw note values to contiguous ids and back.

    Raises ValueError when no cached vocab exists and *samples_path* is None.
    """
    vocab_file = os.path.join(vocab_path, "vocab.npy")

    if os.path.exists(vocab_file):
        print("vocab.npy found, loading from disk...")
        vocab = np.load(vocab_file)
    elif samples_path is not None:
        print("vocab.npy not found, adapting from sample files...")
        csv_files = glob.glob(os.path.join(samples_path, "*.csv"))
        # Flatten every chorale's 4 voice columns into one pool of notes,
        # then keep the sorted unique values as the vocabulary.
        all_notes = np.hstack(
            [np.loadtxt(f, delimiter=",", skiprows=1).flatten() for f in csv_files]
        )
        vocab = np.unique(all_notes)
        os.makedirs(vocab_path, exist_ok=True)
        np.save(vocab_file, vocab)
        print(f"vocab adapted and saved to {vocab_file}")
    else:
        raise ValueError("vocab file not found and samples_path not provided.")

    note2id = keras.layers.IntegerLookup(num_oov_indices=0, vocabulary=vocab)
    id2note = keras.layers.IntegerLookup(num_oov_indices=0, vocabulary=vocab, invert=True)
    return note2id, id2note, vocab
|
| 30 |
+
|
| 31 |
+
def parse_and_flatten(line):
    """Decode one CSV line (four integer voice columns) and emit each note separately."""
    voices = tf.io.decode_csv(line, [0, 0, 0, 0])
    return tf.data.Dataset.from_tensor_slices(voices)
|
| 35 |
+
|
| 36 |
+
def seq2seq_from_chorale(path, seq_len, window_shift):
    """Convert a single chorale CSV file into overlapping input/target windows.

    Each sliding window contains ``seq_len + window_shift`` notes; the input
    is the window without its last ``window_shift`` notes and the target is
    the same window shifted forward by ``window_shift`` (next-note targets).
    The header row of the CSV is skipped.
    """
    return tf.data.TextLineDataset(path).skip(1)\
        .flat_map(parse_and_flatten)\
        .window(seq_len + window_shift, shift=window_shift, drop_remainder=True)\
        .flat_map(lambda window: window.batch(seq_len + window_shift))\
        .map(lambda seq: (seq[:-window_shift], seq[window_shift:]), AUTOTUNE)
|
| 43 |
+
|
| 44 |
+
def seq2seq_dataset(files_path, lookup_fn, seq_len=256, window_shift=1,
                    batch_size=64, shuffle_buffer=None, seed=42):
    """Build a batched seq2seq dataset from every chorale CSV matched by *files_path*.

    Windows each chorale into input/target sequences (see
    ``seq2seq_from_chorale``), encodes notes to ids via *lookup_fn*, caches
    the encoded pairs, optionally shuffles, and returns prefetched batches.
    """
    dataset = tf.data.Dataset.list_files(files_path, shuffle=False)\
        .map(lambda file_path: seq2seq_from_chorale(file_path, seq_len, window_shift), AUTOTUNE)\
        .flat_map(lambda per_file_ds: per_file_ds)\
        .map(lambda inp, tar: (lookup_fn(inp), lookup_fn(tar)), AUTOTUNE)\
        .cache()

    if shuffle_buffer:
        dataset = dataset.shuffle(shuffle_buffer, seed=seed)

    return dataset.batch(batch_size).prefetch(AUTOTUNE)
|
src/inference.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from music21 import stream, chord
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
import keras
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random
|
| 6 |
+
import glob
|
| 7 |
+
|
| 8 |
+
def predict_next_token(model, input_sequence, temperature=1, seed=42):
    """Sample the next token id from the model's final-timestep logits.

    Logits are divided by *temperature* before categorical sampling, so
    higher values flatten the distribution (more randomness) and lower
    values sharpen it.
    """
    assert keras.ops.ndim(input_sequence) == 2, "function expects input_sequence to be (batch_size, sequence_len)"
    last_step_logits = model.predict_on_batch(input_sequence)[:, -1, :]
    return tf.random.categorical(last_step_logits / temperature, num_samples=1, seed=seed)
|
| 14 |
+
|
| 15 |
+
def generate_sequence(init_context, model, include_init_context=False, max_len=25, temperature=1, seed=42):
    """Autoregressively extend a seed sequence by sampling from the model.

    Samples ``max_len * 4`` tokens (presumably 4 per chord, one per voice —
    downstream code regroups output into rows of 4), appending each to the
    growing context before the next prediction.
    """
    assert keras.ops.ndim(init_context) == 2, "function expects init_context to be (batch_size, sequence_len)"
    context_len = init_context.shape[1]
    generated = init_context
    for _ in range(max_len * 4):
        token = predict_next_token(model, generated, temperature=temperature, seed=seed)
        generated = keras.ops.concatenate([generated, token], axis=1)
    if include_init_context:
        return generated
    return generated[:, context_len:]
|
| 24 |
+
|
| 25 |
+
def generate_chorale(model, sample_seed_path, note2id, id2note, file_name="samples/chorale.mid", max_len=25, temperature=1,
                     sample_seed_rows: slice = slice(0, 100), include_init_context=False, seed=42):
    """Generate a Bach-style chorale and write it to a MIDI file.

    Loads a seed chorale from CSV, encodes it, autoregressively extends it
    with the trained model, then decodes the ids back to notes grouped as
    4-note chords and writes the result via music21.
    """
    # Flatten the CSV's voice columns and keep only the requested rows as seed.
    raw_seed = np.loadtxt(sample_seed_path, skiprows=1, delimiter=",").flatten()
    seed_ids = note2id(raw_seed[sample_seed_rows].reshape(1, -1))
    generated = generate_sequence(seed_ids, model, include_init_context=include_init_context,
                                  max_len=max_len, temperature=temperature, seed=seed)
    # Map ids back to note values and regroup into 4-voice chords.
    notes = keras.ops.convert_to_numpy(keras.ops.reshape(id2note(generated), (-1, 4)))
    midi_stream = stream.Stream([chord.Chord(row.tolist()) for row in notes])
    midi_stream.write('midi', fp=file_name)
    print(f"chorale saved as {file_name}")
|
| 36 |
+
|
| 37 |
+
def draw_random_sample(csv_dir, seed=42):
    """Return a deterministic pseudo-random CSV file path from *csv_dir*.

    Uses a private ``random.Random`` instance so the module-level RNG state
    is not mutated (the original ``random.seed`` call reseeded the global
    generator as a side effect), and sorts the glob results so the choice is
    reproducible regardless of filesystem ordering.

    Raises FileNotFoundError when the directory contains no CSV files.
    """
    files = sorted(glob.glob(csv_dir + '/*.csv'))
    if not files:
        raise FileNotFoundError(f"no CSV files found in {csv_dir}")
    return random.Random(seed).choice(files)
|
src/metrics.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import keras
|
| 2 |
+
|
| 3 |
+
class Preplexity(keras.metrics.Metric):
    """Perplexity metric: exp of the running mean token-level cross-entropy.

    NOTE: the name keeps the original "Preplexity" spelling because it must
    match the metric keys in the saved model and train_logs.json.
    """
    def __init__(self, name="Preplexity", **kwargs):
        super().__init__(name=name, **kwargs)
        # Running mean of per-token cross-entropy, accumulated across batches.
        self.cross_entropy = keras.metrics.Mean()

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate cross-entropy for one batch; expects y_pred to be logits."""
        ce = keras.losses.sparse_categorical_crossentropy(y_true, y_pred , from_logits=True)
        # mean over batch and seq_len dimensions
        self.cross_entropy.update_state(ce, sample_weight=sample_weight)

    def result(self):
        # perplexity = exp(mean cross-entropy)
        return keras.ops.exp(self.cross_entropy.result())

    def reset_state(self):
        self.cross_entropy.reset_state()
|
src/model.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import keras
|
| 2 |
+
from metrics import Preplexity
|
| 3 |
+
|
| 4 |
+
def get_model(lr, weight_decay, emb_in, emb_out, lstm_layers, lstm_units, lstm_dropout, dense_units, dropout):
    """Build and compile the BachNet next-note prediction model.

    Architecture: Embedding -> lstm_layers x (LSTM + LayerNormalization) ->
    optional Dense/Dropout head -> linear logits over the emb_in-sized note
    vocabulary. Compiled with sparse categorical cross-entropy on logits,
    Nadam (weight decay + gradient-norm clipping), accuracy and the custom
    Preplexity metric, with XLA JIT compilation enabled.
    """
    assert lstm_layers >= 1, "expect at least one LSTM layer"

    model = keras.Sequential([], name="BachModel")
    model.add(keras.layers.Embedding(emb_in, emb_out, name="Embedding_Layer", input_shape=(None,)))

    # Recurrent trunk: every LSTM returns full sequences so the model emits
    # a next-note distribution at every timestep (seq2seq training).
    for idx in range(lstm_layers):
        model.add(keras.layers.LSTM(lstm_units, return_sequences=True, dropout=lstm_dropout,
                                    name=f"LSTM_Layer_{idx}"))
        model.add(keras.layers.LayerNormalization(name=f"Layer_Norm_{idx}"))

    # Optional fully-connected head; skipped entirely when dense_units == 0.
    if dense_units > 0:
        model.add(keras.layers.Dense(dense_units, activation="relu", name="Dense_Layer",
                                     kernel_regularizer=keras.regularizers.L2(1e-5)))
        model.add(keras.layers.Dropout(dropout, name="Dropout_Layer"))

    model.add(keras.layers.Dense(emb_in, name="Logits"))

    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Nadam(lr, weight_decay=weight_decay, clipnorm=1.0),
        metrics=[Preplexity(), "accuracy"],
        jit_compile=True,
    )

    return model
|
src/trainer.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import keras
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
def exp_decay(epoch, lr):
    """Per-epoch schedule function for keras.callbacks.LearningRateScheduler.

    Multiplies the *current* learning rate by 0.1 ** (epoch / 40). Since the
    scheduler passes in the already-decayed rate from the previous epoch,
    the decay compounds over training (this matches the lr trace recorded in
    artifacts/train_logs.json).
    """
    decay_factor = 0.1 ** (epoch / 40)
    return lr * decay_factor
|
| 7 |
+
|
| 8 |
+
def train_model(bach_model, train, val, n_epochs, ARTIFACTS_PATH, MODEL_PATH):
    """Fit the model, save it, and dump the training history to JSON.

    Side effects: writes a per-epoch checkpoint and train_logs.json under
    ARTIFACTS_PATH, and the final model under MODEL_PATH.
    """
    callbacks = [
        keras.callbacks.LearningRateScheduler(exp_decay),
        # restore_best_weights=False: the last-epoch weights are kept even
        # if the monitored validation loss has started rising.
        keras.callbacks.EarlyStopping(patience=3, restore_best_weights=False, verbose=True, min_delta=5e-5),
        keras.callbacks.ModelCheckpoint(os.path.join(ARTIFACTS_PATH, "checkpoint.keras"), verbose=1),
    ]

    history = bach_model.fit(train, validation_data=val, epochs=n_epochs, callbacks=callbacks)

    bach_model.save(os.path.join(MODEL_PATH, "bach_model.keras"))

    with open(os.path.join(ARTIFACTS_PATH, "train_logs.json"), "w") as log_file:
        json.dump(history.history, log_file, indent=4)
|
src/utils.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import keras
|
| 2 |
+
import os
|
| 3 |
+
import subprocess
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
def get_dataset_path(root_dir, URL):
    """Download and extract the chorales CSV dataset and configure file paths.

    Fetches the zip from *URL* into ``root_dir/data`` (keras caches it, so
    repeat calls skip the download), then returns
    ``(TRAIN_PATH, VAL_PATH, ARTIFACTS_PATH, MODEL_PATH)``, creating the
    artifacts and model directories if they do not exist.
    """
    DATASET_PATH = keras.utils.get_file(
        "jsb_chorales.zip",
        URL,
        extract=True,
        cache_dir=root_dir,
        cache_subdir="data",
    )

    TRAIN_PATH = os.path.join(DATASET_PATH, "jsb_chorales/train")
    VAL_PATH = os.path.join(DATASET_PATH, "jsb_chorales/val")
    ARTIFACTS_PATH = os.path.join(root_dir, "artifacts")
    MODEL_PATH = os.path.join(root_dir, "model")

    os.makedirs(ARTIFACTS_PATH, exist_ok=True)
    os.makedirs(MODEL_PATH, exist_ok=True)
    return TRAIN_PATH, VAL_PATH, ARTIFACTS_PATH, MODEL_PATH
|
| 24 |
+
|
| 25 |
+
def midi_to_wave(midi_file_path, SF2_PATH, wave_path="samples/sample.wav"):
    """Render a MIDI file to a WAV audio file using the FluidSynth CLI.

    Raises FileNotFoundError when either input file is missing, and
    RuntimeError (with FluidSynth's stderr) when the conversion fails.
    """
    if not os.path.exists(midi_file_path):
        raise FileNotFoundError(f"MIDI file not found: {midi_file_path}")
    if not os.path.exists(SF2_PATH):
        raise FileNotFoundError(f"SoundFont file not found: {SF2_PATH}")

    os.makedirs(os.path.dirname(wave_path), exist_ok=True)
    # -ni: non-interactive, no MIDI-in; -F: render to file; -r: sample rate.
    command = ["fluidsynth", "-ni", "-F", wave_path, "-r", "44100", SF2_PATH, midi_file_path]

    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as err:
        raise RuntimeError(f"FluidSynth failed: {err.stderr}")

    print(f"WAV file saved at {wave_path}")
|
| 41 |
+
|
| 42 |
+
# Repository assets root, resolved relative to this source file (src/../assets).
ASSETS_DIR = Path(__file__).parent.parent / "assets"

def load_css():
    """Return the Gradio theme stylesheet (assets/css/theme.css) as a string."""
    return (ASSETS_DIR / "css/theme.css").read_text(encoding="utf-8")

def load_markdown(name):
    """Return the contents of assets/markdown/<name>.md as a string."""
    return (ASSETS_DIR / f"markdown/{name}.md").read_text(encoding="utf-8")
|