Spaces:
Sleeping
Sleeping
v0
Browse files- .gitignore +4 -0
- Loss_per_epoch.png +0 -0
- app.py +46 -0
- app.txt +3 -0
- checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t0 +0 -0
- checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t20 +0 -0
- checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t40 +0 -0
- checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t60 +0 -0
- checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t80 +0 -0
- checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t99 +0 -0
- checkpoint/model.pth +0 -0
- config/vocab.json +1 -0
- convert.py +33 -0
- data/.DS_Store +0 -0
- data/music.txt +0 -0
- data/pop.txt +0 -0
- data/sample-music.txt +25 -0
- model.py +66 -0
- requirments.txt +3 -0
- train.py +294 -0
- utils.py +91 -0
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
output/*
|
| 3 |
+
temp.ipynb
|
| 4 |
+
output*
|
Loss_per_epoch.png
ADDED
|
app.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
import gradio as gr
|
| 5 |
+
from model import MusicLSTM
|
| 6 |
+
from train import DataLoader, Config, generate_song as generate_ABC_notation
|
| 7 |
+
from utils import load_vocab
|
| 8 |
+
from convert import abc_to_audio
|
| 9 |
+
|
| 10 |
+
class GradioApp():
    """Gradio front-end that turns the trained RNN's ABC output into audio.

    On construction this loads the training corpus (for the DataLoader),
    the model checkpoint, and the saved vocabulary, then builds a
    one-button Gradio interface whose output is the rendered audio file.
    """

    def __init__(self):
        # Set up configuration and data
        self.config = Config()
        self.CHECKPOINT_FILE = "checkpoint/model.pth"
        self.data_loader = DataLoader(self.config.INPUT_FILE, self.config)
        # weights_only=False performs a full unpickle -- acceptable only
        # because the checkpoint is a trusted local file shipped with the app.
        self.checkpoint = torch.load(self.CHECKPOINT_FILE, weights_only=False)
        # Size the model from the saved vocabulary, not the corpus, so the
        # dimensions match what the checkpoint was trained with.
        char_idx, char_list = load_vocab()
        self.model = MusicLSTM(
            input_size=len(char_idx),
            hidden_size=self.config.HIDDEN_SIZE,
            output_size=len(char_idx),
        )
        self.model.load_state_dict(self.checkpoint)
        self.model.eval()

        # Set up the interface: a single trigger button in, audio out.
        self.input = gr.Button("")
        self.output = gr.Audio(label="Generated Music")
        self.interface = gr.Interface(fn=self.generate_music, inputs=self.input, outputs=self.output, title="AI Music Generator", description="Generate a new song using a trained RNN model.")

    def launch(self):
        """Start the Gradio server (blocking)."""
        self.interface.launch()

    def generate_music(self, input):
        """Generate a new song with the trained model and render it to audio.

        Args:
            input: Value from the trigger button; unused.

        Returns:
            Path of the generated audio file, as produced by abc_to_audio.
        """
        abc_notation = generate_ABC_notation(self.model, self.data_loader)
        # Bug fix: str.strip("<start>") strips any of the characters
        # '<','s','t','a','r','>' from BOTH ends (a character-set strip,
        # not a prefix strip), which can also eat leading/trailing notation
        # characters.  Remove the literal markers instead.
        abc_notation = abc_notation.strip()
        abc_notation = abc_notation.removeprefix("<start>").removesuffix("<end>").strip()
        audio = abc_to_audio(abc_notation)
        return audio
|
| 41 |
+
|
| 42 |
+
if __name__ == '__main__':
    # Build the app and serve it.
    GradioApp().launch()
|
| 45 |
+
|
| 46 |
+
|
app.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
libfluidsynth-dev
|
| 2 |
+
libsndfile1
|
| 3 |
+
abc2midi
|
checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t0
ADDED
|
Binary file (844 kB). View file
|
|
|
checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t20
ADDED
|
Binary file (844 kB). View file
|
|
|
checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t40
ADDED
|
Binary file (844 kB). View file
|
|
|
checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t60
ADDED
|
Binary file (844 kB). View file
|
|
|
checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t80
ADDED
|
Binary file (844 kB). View file
|
|
|
checkpoint/ckpt_mdl_lstm_ep_100_hsize_150_dout_0.t99
ADDED
|
Binary file (844 kB). View file
|
|
|
checkpoint/model.pth
ADDED
|
Binary file (839 kB). View file
|
|
|
config/vocab.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"char_idx": "<]=_4Xl)uq5CBw#d(~H}scntZ!hIF6p'\\E/g&?fTW{^-v9MA710+oK\tJS[\n,Q \"G2a:L|mxVNbPRk*jYyD3e.8Oi>Uzr@", "char_list": ["<", "]", "=", "_", "4", "X", "l", ")", "u", "q", "5", "C", "B", "w", "#", "d", "(", "~", "H", "}", "s", "c", "n", "t", "Z", "!", "h", "I", "F", "6", "p", "'", "\\", "E", "/", "g", "&", "?", "f", "T", "W", "{", "^", "-", "v", "9", "M", "A", "7", "1", "0", "+", "o", "K", "\t", "J", "S", "[", "\n", ",", "Q", " ", "\"", "G", "2", "a", ":", "L", "|", "m", "x", "V", "N", "b", "P", "R", "k", "*", "j", "Y", "y", "D", "3", "e", ".", "8", "O", "i", ">", "U", "z", "r", "@"]}
|
convert.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from music21 import converter, stream
|
| 2 |
+
from midi2audio import FluidSynth
|
| 3 |
+
import subprocess
|
| 4 |
+
|
| 5 |
+
def abc_to_audio(abc_notation, output_format='wav', sound_font="FluidR3_GM.sf2"):
    """Convert ABC notation to a WAV audio file.

    Writes the notation to ``output.abc``, converts it to MIDI with the
    external ``abc2midi`` tool, then renders the MIDI to ``output.wav``
    with FluidSynth.

    Args:
        abc_notation: A complete song in ABC notation (headers + tune body).
        output_format: Unused; kept for interface compatibility (only WAV
            is produced).
        sound_font: Unused; kept for interface compatibility -- FluidSynth's
            default configuration is used.

    Returns:
        Path of the generated audio file, "output.wav".

    Raises:
        subprocess.CalledProcessError: if abc2midi exits with an error.
    """
    abc_file = 'output.abc'
    with open(abc_file, 'w') as f:
        f.write(abc_notation)
    # check=True so a failed conversion raises instead of silently leaving a
    # stale or missing MIDI file for the synth step to trip over.
    subprocess.run(['abc2midi', abc_file, '-o', "output.midi"], check=True)
    fs = FluidSynth()
    fs.midi_to_audio("output.midi", "output.wav")
    return "output.wav"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
if __name__ == '__main__':
    # Smoke test: render a hard-coded sample tune (in the same ABC format the
    # model emits, headers included) to output.wav.
    abc_to_audio("""X:12
T:Byrne: Triop
C:Trad Figne
Z:id:hn-hornpipe-53
M:C|
K:G
(3DFB d2dc | def2 edef | e2a2 df | g4- gdBG | A4G | A4 :|
|: ae edc | edcB A2B2 | A2G2 | G6 d2 | e4^c4 | d4 d4 | ed e2 | d4 ||
P:variations:
|: ABA AGE|F2A d2A|d2g d2:|
a2f fef aba|a2f g2e fed|c2A GBd|f2g g2a|bgb aag|dcB B2G|A2G A2G:|
|:F2A A2G|AGE G2d||
P:variations
|: AGF GBd | cde d2B | c2c c2A :|
|: de fe | fdfe dFAd | A2AG A2f2 | g2ag e2B2 | A2AB ^cdce | d2d>c | B4z2 | B4 | A4G2 | ^F4G4 | G4 :|
|: G^F G2 | c4 ||
GBdB | c2 ded2 | c2B2c2 | d2c2B2 | c2d2 | c2B2 | A4 :|""")
|
data/.DS_Store
ADDED
|
Binary file (30.7 kB). View file
|
|
|
data/music.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/pop.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/sample-music.txt
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
X:3
|
| 2 |
+
T:Badine
|
| 3 |
+
O:France
|
| 4 |
+
A:Provence
|
| 5 |
+
Z:Transcrit et/ou corrig? par Michel BELLON - 2005-04-01
|
| 6 |
+
Z:Pour toute observation mailto:galouvielle@free.fr
|
| 7 |
+
M:C|
|
| 8 |
+
L:1/8
|
| 9 |
+
Q: "Allegro"
|
| 10 |
+
K:Bb
|
| 11 |
+
V:1 name=G
|
| 12 |
+
d2 (cB) | d2 (cB) f2 ed | f4 g2 g2 | feed eddc | d2B2 d2cB | d2cB f2ed |
|
| 13 |
+
f4 g2g2 | fedc d2!+!c2 | B4 :: FBcB | B2AB cd ec | d2 B2 df dB |
|
| 14 |
+
cf cA Bd cB | B2 A2 f2 f2 | (ABcd) e2 d2 | e2 dc dcde | fefg fefg |
|
| 15 |
+
Te4 d2cB | d2cB f2ed | f4 g2g2 | feed eddc | d2B2 d2cB | d2cB f2ed |
|
| 16 |
+
f4 g2g2 | fedc d2!+!c2 | B4 !fine! :: [K:Bbm] c2de | d2ef B2cd |
|
| 17 |
+
c2F2 dc Bc | !+!=A2 B2 c2 d2 | d2 c2 d2ef | g2g2 c2de | f2f2 B=ABc | F2Bc d2c2 |
|
| 18 |
+
B4:| fefg | f2e=d e2f2 | {f2}g4 edef | e2dc d2e2 | {e2}f4 Bc=Ac |
|
| 19 |
+
B2F2 dece | d2c2 dcde | f6 ed |c4 !D.C.! |]
|
| 20 |
+
V:2 name=V
|
| 21 |
+
z4 | z4 d2cB | d2cB e4 | B4 f2f2 | b2B2 z4 | z4 d2cB | d2cB e4 | d2B2 f2F2 | B4 ::
|
| 22 |
+
z4 | f4f2f2 | B2b2 b2b2 | a2f2 g2=e2 | f2F2 z4 | f2f2 A2B2 | c2f2 B2B2 | B2b2 b2b2 |
|
| 23 |
+
a4 b2B2 | z4 d2cB | d2cB e4 | B4 f2f2 | b2B2 z4 | z4 d2cB | d2cB e4 | d2B2 f2F2 | B4 ::
|
| 24 |
+
[K:Bbm] b2b2 | b4 g2e2 | f2F2 B2B2 | e2d2 c2B2 | f4 b2b2 | e2fg a2a2 | d4 g2g2 | fede f2F2 | B4 :|
|
| 25 |
+
b2b2 | b3a g2f2 | e4 a2a2 | a3g f2e2 | d4 z4 | z4 (bc')(=ac') | b2f2 B2Bc | dcde d2e2 | f4 |]
|
model.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from torch.autograd import Variable
|
| 6 |
+
|
| 7 |
+
class MusicLSTM(nn.Module):
    """Character-level recurrent network (LSTM or GRU) over a music vocabulary.

    Pipeline: embedding -> recurrent layer -> dropout -> linear projection
    to per-character logits.  The recurrent hidden state is cached on the
    instance as ``self.hidden``, so callers must call ``init_hidden()``
    before starting a new sequence.
    """

    def __init__(self, input_size, hidden_size, output_size, model='lstm', num_layers=1, dropout_p=0):
        """Create the layers.

        Args:
            input_size: Vocabulary size (number of distinct characters).
            hidden_size: Width of both the embedding and the recurrent state.
            output_size: Number of output logits (normally == input_size).
            model: 'lstm' or 'gru'; any other value raises NotImplementedError.
            num_layers: Number of stacked recurrent layers.
            dropout_p: Dropout probability applied to the recurrent output.
        """
        super(MusicLSTM, self).__init__()
        self.model = model
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers

        # Embedding and recurrent layers share the same width.
        self.embeddings = nn.Embedding(input_size, hidden_size)
        if self.model == 'lstm':
            self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers)
        elif self.model == 'gru':
            self.rnn = nn.GRU(hidden_size, hidden_size, num_layers)
        else:
            raise NotImplementedError

        self.out = nn.Linear(self.hidden_size, self.output_size)
        self.drop = nn.Dropout(p=dropout_p)

    def init_hidden(self, batch_size=1):
        """Initialize hidden states.

        LSTM state is an (h, c) tuple of zeros; GRU state is a single zero
        tensor.  The state is stored on ``self.hidden`` and also returned.
        """
        if self.model == 'lstm':
            self.hidden = (
                torch.zeros(self.num_layers, batch_size, self.hidden_size),
                torch.zeros(self.num_layers, batch_size, self.hidden_size)
            )
        elif self.model == 'gru':
            self.hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size)

        return self.hidden

    def forward(self, x):
        """Forward pass.

        Args:
            x: Tensor of character indices, 1-D (seq,) or 2-D; higher-rank
               inputs are squeezed first.

        Returns:
            Logits with shape (seq_len, output_size) for a 1-D input.
        """
        # Ensure x is 2D (sequence length, batch size)
        # NOTE(review): squeeze() drops ALL size-1 dims, which would also
        # remove a batch dimension of 1 -- confirm callers never rely on it.
        if x.dim() > 2:
            x = x.squeeze()

        batch_size = 1 if x.dim() == 1 else x.size(0)
        x = x.long()  # nn.Embedding requires integer indices.

        # Embed the input
        embeds = self.embeddings(x)

        # Initialize hidden state if not already done
        # (callers normally reset it explicitly via init_hidden()).
        if not hasattr(self, 'hidden'):
            self.init_hidden(batch_size)

        # Ensure embeds is 3D for RNN input (sequence length, batch size, embedding size)
        if embeds.dim() == 2:
            embeds = embeds.unsqueeze(1)

        # RNN processing -- the hidden state carries over between calls
        # until init_hidden() is invoked again.
        rnn_out, self.hidden = self.rnn(embeds, self.hidden)

        # Dropout and output layer (batch dim of 1 is squeezed away first).
        rnn_out = self.drop(rnn_out.squeeze(1))
        output = self.out(rnn_out)

        return output
|
requirments.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
music21
|
| 2 |
+
midi2audio
|
| 3 |
+
pyfluidsynth
|
train.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import random
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from model import MusicLSTM as MusicRNN
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.optim as optim
|
| 12 |
+
from torch.autograd import Variable
|
| 13 |
+
from utils import seq_to_tensor, load_vocab, save_vocab
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def logger(active=True):
    """Return a print-like function that becomes a no-op when *active* is False."""
    def log(*args, **kwargs):
        if not active:
            return
        print(*args, **kwargs)
    return log
|
| 22 |
+
|
| 23 |
+
# Configuration
class Config:
    """Hyper-parameters and run settings for training (class-level constants)."""
    SAVE_EVERY = 20          # checkpoint every this many epochs
    SEQ_SIZE = 25            # characters per training slice
    RANDOM_SEED = 11         # seed for the train/validation shuffle
    VALIDATION_SIZE = 0.15   # fraction of songs held out for validation
    LR = 1e-3                # Adam learning rate
    N_EPOCHS = 100
    NUM_LAYERS = 1           # stacked recurrent layers
    HIDDEN_SIZE = 150        # embedding / hidden dimension
    DROPOUT_P = 0            # dropout on the recurrent output
    MODEL_TYPE = 'lstm'      # 'lstm' or 'gru'
    INPUT_FILE = 'data/music.txt'  # <start>/<end>-delimited song corpus
    RESUME = False           # not referenced in this file -- TODO confirm before use
    BATCH_SIZE = 1           # not referenced in this file; training is per-song
|
| 38 |
+
|
| 39 |
+
# Utility functions
|
| 40 |
+
def tic():
    """Return the current wall-clock time, to be paired with toc()."""
    start = time.time()
    return start
|
| 43 |
+
|
| 44 |
+
def toc(start_time, msg=None):
    """Format the time elapsed since *start_time* as 'Xm Ys', optionally
    followed by *msg*."""
    elapsed = time.time() - start_time
    minutes, seconds = divmod(elapsed, 60)
    stamp = f'{int(minutes)}m {int(seconds)}s'
    return f'{stamp} {msg}' if msg else stamp
|
| 51 |
+
|
| 52 |
+
class DataLoader:
    """Loads the song corpus, builds the character vocabulary, and serves
    random (input, target) training slices.

    Songs in the corpus are delimited by literal '<start>' and '<end>'
    lines; the vocabulary is the set of every character in the file.
    """

    def __init__(self, input_file, config):
        """Read the corpus and split it into train/validation index lists.

        Args:
            input_file: Path to the '<start>'/'<end>'-delimited corpus.
            config: Config instance (uses SEQ_SIZE, RANDOM_SEED,
                VALIDATION_SIZE).
        """
        self.config = config
        self.char_idx, self.char_list = self._load_chars(input_file)
        self.data = self._load_data(input_file)
        self.train_idxs, self.valid_idxs = self._split_data()
        log = logger(True)
        log(f"Total songs: {len(self.data)}")
        log(f"Training songs: {len(self.train_idxs)}")
        log(f"Validation songs: {len(self.valid_idxs)}")

    def _load_chars(self, input_file):
        """Load unique characters from the input file.

        NOTE(review): the ordering of ''.join(set(...)) is not stable
        across interpreter runs, which is why the vocabulary is persisted
        (utils.save_vocab) so generation can reuse the training ordering.
        """
        with open(input_file, 'r') as f:
            char_idx = ''.join(set(f.read()))
        return char_idx, list(char_idx)

    def _load_data(self, input_file):
        """Load song data from the input file, one string per song.

        A song accumulates lines from '<start>' through '<end>' inclusive;
        lines outside any marker pair are appended to the current buffer.
        """
        with open(input_file, "r") as f:
            data, buffer = [], ''
            for line in f:
                if line == '<start>\n':
                    buffer += line
                elif line == '<end>\n':
                    buffer += line
                    data.append(buffer)
                    buffer = ''
                else:
                    buffer += line

        # Filter songs shorter than sequence size
        # (+10 margin so rand_slice always has room for a full slice).
        data = [song for song in data if len(song) > self.config.SEQ_SIZE + 10]
        return data

    def _split_data(self):
        """Split song indices into training and validation sets.

        Returns:
            (train_idxs, valid_idxs): index lists into self.data.
        """
        num_train = len(self.data)
        indices = list(range(num_train))

        # Fixed seed => the same split on every run.
        np.random.seed(self.config.RANDOM_SEED)
        np.random.shuffle(indices)

        split_idx = int(np.floor(self.config.VALIDATION_SIZE * num_train))
        train_idxs = indices[split_idx:]
        valid_idxs = indices[:split_idx]

        return train_idxs, valid_idxs

    def rand_slice(self, data, slice_len=None):
        """Return a random slice of slice_len + 1 characters of *data*
        (one shorter when the slice starts at the very end)."""
        if slice_len is None:
            slice_len = self.config.SEQ_SIZE

        d_len = len(data)
        s_idx = random.randint(0, d_len - slice_len)
        # +1: the extra trailing character supplies the shifted target.
        e_idx = s_idx + slice_len + 1
        return data[s_idx:e_idx]

    def seq_to_tensor(self, seq):
        """Convert a character sequence to a LongTensor of vocab indices."""
        out = torch.zeros(len(seq)).long()
        for i, c in enumerate(seq):
            out[i] = self.char_idx.index(c)
        return out

    def song_to_seq_target(self, song):
        """Return (input, target) tensors; target is input shifted by one."""
        try:
            a_slice = self.rand_slice(song)
            seq = self.seq_to_tensor(a_slice[:-1])
            target = self.seq_to_tensor(a_slice[1:])
            return seq, target
        except Exception as e:
            print(f"Error in song_to_seq_target: {e}")
            print(f"Song length: {len(song)}")
            raise
|
| 129 |
+
|
| 130 |
+
def train_model(config, data_loader, model, optimizer, loss_function):
    """Training loop for the model.

    Runs config.N_EPOCHS epochs.  Each epoch takes one random SEQ_SIZE
    slice per training song, then runs a validation pass, and periodically
    saves a checkpoint.

    Args:
        config: Config with the hyper-parameters.
        data_loader: DataLoader providing songs and train/valid index lists.
        model: Network exposing init_hidden() and __call__(seq) -> logits.
        optimizer: Optimizer over model.parameters().
        loss_function: Criterion over (logits, targets), e.g. CrossEntropyLoss.

    Returns:
        (losses, v_losses): per-epoch mean training / validation losses.
    """
    log = logger(True)
    time_since = tic()
    losses, v_losses = [], []

    for epoch in range(config.N_EPOCHS):
        # Training phase
        epoch_loss = 0
        model.train()

        for i, song_idx in enumerate(data_loader.train_idxs):
            try:
                seq, target = data_loader.song_to_seq_target(data_loader.data[song_idx])

                # Reset hidden state and gradients
                model.init_hidden()
                optimizer.zero_grad()

                # Forward pass
                outputs = model(seq)
                loss = loss_function(outputs, target)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

                # \r keeps the progress report on a single console line.
                msg = f'\rTraining Epoch: {epoch}, {(i+1)/len(data_loader.train_idxs)*100:.2f}% iter: {i} Time: {toc(time_since)} Loss: {loss.item():.4f}'
                sys.stdout.write(msg)
                sys.stdout.flush()

            except Exception as e:
                # Best-effort: skip a song that fails rather than abort the run.
                log(f"Error processing song {song_idx}: {e}")
                continue

        print()
        losses.append(epoch_loss / len(data_loader.train_idxs))

        # Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for i, song_idx in enumerate(data_loader.valid_idxs):
                try:
                    seq, target = data_loader.song_to_seq_target(data_loader.data[song_idx])

                    # Reset hidden state
                    model.init_hidden()

                    # Forward pass
                    outputs = model(seq)
                    loss = loss_function(outputs, target)

                    val_loss += loss.item()

                    msg = f'\rValidation Epoch: {epoch}, {(i+1)/len(data_loader.valid_idxs)*100:.2f}% iter: {i} Time: {toc(time_since)} Loss: {loss.item():.4f}'
                    sys.stdout.write(msg)
                    sys.stdout.flush()

                except Exception as e:
                    log(f"Error processing validation song {song_idx}: {e}")
                    continue

        print()
        v_losses.append(val_loss / len(data_loader.valid_idxs))

        # Checkpoint saving
        if epoch % config.SAVE_EVERY == 0 or epoch == config.N_EPOCHS - 1:
            log('=======> Saving..')
            # NOTE(review): this `state` dict is assembled but never written;
            # torch.save below pickles the whole model object instead, so the
            # optimizer state and loss history are NOT persisted -- confirm
            # whether that is intended.
            state = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': losses[-1],
                'v_loss': v_losses[-1],
                'losses': losses,
                'v_losses': v_losses,
                'epoch': epoch,
            }
            os.makedirs('checkpoint', exist_ok=True)
            torch.save(model, f'checkpoint/ckpt_mdl_{config.MODEL_TYPE}_ep_{config.N_EPOCHS}_hsize_{config.HIDDEN_SIZE}_dout_{config.DROPOUT_P}.t{epoch}')

    return losses, v_losses
|
| 214 |
+
|
| 215 |
+
def plot_losses(losses, v_losses):
    """Render the training and validation loss curves, one point per epoch."""
    plt.figure(figsize=(10, 5))
    for series, series_name in ((losses, 'Training Loss'),
                                (v_losses, 'Validation Loss')):
        plt.plot(series, label=series_name)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss per Epoch')
    plt.legend()
    plt.show()
|
| 225 |
+
|
| 226 |
+
def generate_song(model, data_loader, prime_str='<start>', max_len=1000, temp=0.8):
    """Generate a new song using the trained model.

    Feeds *prime_str* through the network to warm up the hidden state,
    then samples one character at a time (temperature-scaled softmax)
    until the text ends with '<end>' or *max_len* characters are produced.

    Args:
        model: Trained network; its hidden state is reset here.
        data_loader: Unused -- the vocabulary comes from config/vocab.json.
        prime_str: Seed text prepended to the generated output.
        max_len: Maximum number of characters to sample.
        temp: Sampling temperature; lower is greedier, higher more random.

    Returns:
        The generated text, including the prime string.
    """
    model.eval()
    model.init_hidden()
    creation = prime_str
    # Use the persisted vocabulary so index<->char mapping matches the
    # ordering the checkpoint was trained with (the corpus set() ordering
    # differs between runs).
    char_idx, char_list = load_vocab()

    # Build up hidden state
    prime = seq_to_tensor(creation, char_idx)

    with torch.no_grad():
        # Feed all but the last primed character purely to build hidden
        # state; the outputs are discarded.  (`_` doubles as the index.)
        for _ in range(len(prime)-1):
            _ = model(prime[_:_+1])

        # Generate rest of sequence
        for _ in range(max_len):
            last_char = prime[-1:]
            out = model(last_char).squeeze()

            # Manual temperature-scaled softmax: exp then normalize.
            out = torch.exp(out/temp)
            dist = out / torch.sum(out)

            # Sample from distribution
            next_char_idx = torch.multinomial(dist, 1).item()
            next_char = char_idx[next_char_idx]

            creation += next_char
            prime = torch.cat([prime, torch.tensor([next_char_idx])], dim=0)

            # Stop once the literal end-of-song marker has just been emitted.
            if creation[-5:] == '<end>':
                break

    return creation
|
| 259 |
+
|
| 260 |
+
def main():
    """Main execution function: load data, train, plot, save vocab, sample.

    The trained model and data loader are published as module globals so
    they remain inspectable after an interactive run.
    """
    # Set up configuration and data
    global model, data_loader
    config = Config()
    data_loader = DataLoader(config.INPUT_FILE, config)

    # Model setup -- vocabulary size on both ends (character in, character out).
    in_size = out_size = len(data_loader.char_idx)
    model = MusicRNN(
        in_size,
        config.HIDDEN_SIZE,
        out_size,
        config.MODEL_TYPE,
        config.NUM_LAYERS,
        config.DROPOUT_P
    )

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=config.LR)
    loss_function = nn.CrossEntropyLoss()

    # Train the model
    losses, v_losses = train_model(config, data_loader, model, optimizer, loss_function)

    # Plot losses
    plot_losses(losses, v_losses)
    # Persist the vocabulary so generation (e.g. app.py) can reuse the exact
    # character ordering this run trained with.
    save_vocab(data_loader)

    # Generate a song
    generated_song = generate_song(model, data_loader)
    print("Generated Song:", generated_song)

if __name__ == "__main__":
    main()
|
utils.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import json
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
def logger(verbose):
    """Return a logging function that prints its arguments only when
    *verbose* is truthy; otherwise it silently discards them."""
    def log(*msg):
        if not verbose:
            return
        print(*msg)
    return log
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Module-level timer state shared by successive progress_bar calls.
last_time = time.time()
begin_time = last_time

def progress_bar(current, total, msg=None):
    """Write a one-line step-time / total-time report to stdout.

    Args:
        current: Zero-based index of the current step; 0 resets the
            total-time clock for a new bar.
        total: Total number of steps (currently unused except in the
            commented-out newline logic below).
        msg: Optional extra text appended to the report.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.
    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)
    msg = ''.join(L)
    sys.stdout.write(msg)
    # \r returns the cursor so the next call overwrites this same line.
    sys.stdout.write('\r')
    #if current < total-1:
    #
    #else:
    #sys.stdout.write('\n')
    sys.stdout.flush()
|
| 39 |
+
|
| 40 |
+
def format_time(seconds):
    """Render a duration in seconds compactly, e.g. '1h3m', '45s500ms'.

    At most the first two non-zero units (in D, h, m, s, ms order) are
    shown; a zero duration renders as '0ms'.
    """
    remainder = seconds
    days = int(remainder / 3600 / 24)
    remainder -= days * 3600 * 24
    hours = int(remainder / 3600)
    remainder -= hours * 3600
    minutes = int(remainder / 60)
    remainder -= minutes * 60
    whole_seconds = int(remainder)
    millis = int((remainder - whole_seconds) * 1000)

    parts = []
    for value, unit in ((days, 'D'), (hours, 'h'), (minutes, 'm'),
                        (whole_seconds, 's'), (millis, 'ms')):
        if value > 0:
            parts.append(f'{value}{unit}')
        if len(parts) == 2:
            break
    return ''.join(parts) if parts else '0ms'
|
| 71 |
+
|
| 72 |
+
def save_vocab(data_loader, vocab_filename="config/vocab.json"):
    """Save the data loader's character vocabulary to a JSON file.

    Persisting is required because DataLoader builds `char_idx` from a
    set, whose ordering differs between interpreter runs; generation must
    reuse the exact ordering the model was trained with.

    Args:
        data_loader: Object exposing `char_idx` (str) and `char_list` (list).
        vocab_filename: Destination path; missing parent directories are
            created (mirrors the makedirs done for 'checkpoint/' in
            training).
    """
    # Robustness fix: previously this raised FileNotFoundError on a fresh
    # checkout where the 'config/' directory did not exist yet.
    parent_dir = os.path.dirname(vocab_filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    vocab = {
        'char_idx': data_loader.char_idx,
        'char_list': data_loader.char_list
    }
    with open(vocab_filename, 'w') as f:
        json.dump(vocab, f)
|
| 80 |
+
|
| 81 |
+
def load_vocab(vocab_filename='config/vocab.json'):
    """Read the vocabulary JSON written by save_vocab.

    Returns:
        (char_idx, char_list): the index string and the character list.
    """
    with open(vocab_filename, 'r') as handle:
        stored = json.load(handle)
    char_idx = stored['char_idx']
    char_list = stored['char_list']
    return char_idx, char_list
|
| 85 |
+
|
| 86 |
+
def seq_to_tensor(seq, char_idx):
    """Encode a character sequence as a 1-D LongTensor of vocabulary indices.

    Raises ValueError (from str.index) if a character is not in *char_idx*.
    """
    indices = [char_idx.index(ch) for ch in seq]
    return torch.tensor(indices, dtype=torch.long)
|