Hard-coded dim_z to 120, made the weights path resolve dynamically relative to the script, and removed unnecessary parameters from the VAE constructor
Browse files
model.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
import tensorflow as tf
|
|
|
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
_CAP = 3501 # Cap for the number of notes
|
|
@@ -112,16 +114,24 @@ class VAECost:
|
|
| 112 |
|
| 113 |
class VAE(tf.keras.Model):
|
| 114 |
|
| 115 |
-
def __init__(self,
|
| 116 |
super(VAE, self).__init__(name=name, **kwargs)
|
| 117 |
self.dim_x = (3, _CAP, 1)
|
| 118 |
-
self.
|
| 119 |
-
self.
|
| 120 |
-
self.analytic_kl = analytic_kl
|
| 121 |
-
self.encoder = Encoder_Z(dim_z=self.dim_z).build()
|
| 122 |
-
self.decoder = Decoder_X(dim_z=self.dim_z).build()
|
| 123 |
self.cost_func = VAECost(self)
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
@tf.function()
|
| 127 |
def train_step(self, data):
|
|
@@ -143,12 +153,12 @@ class VAE(tf.keras.Model):
|
|
| 143 |
|
| 144 |
mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
|
| 145 |
sd = tf.math.log(1 + tf.math.exp(rho))
|
| 146 |
-
z_sample = mu + sd * tf.random.normal(shape=(
|
| 147 |
return z_sample, mu, sd
|
| 148 |
|
| 149 |
def generate(self, z_sample=None):
|
| 150 |
# Decode a latent representation of a song, which is provided or sampled
|
| 151 |
|
| 152 |
if z_sample == None:
|
| 153 |
-
z_sample = tf.expand_dims(tf.random.normal(shape=(
|
| 154 |
return self.decoder(z_sample)
|
|
|
|
| 1 |
import tensorflow as tf
|
| 2 |
+
import os
|
| 3 |
+
import inspect
|
| 4 |
|
| 5 |
|
| 6 |
_CAP = 3501 # Cap for the number of notes
|
|
|
|
| 114 |
|
| 115 |
class VAE(tf.keras.Model):
|
| 116 |
|
| 117 |
+
def __init__(self, **kwargs):
    """Build the VAE sub-models and load pretrained weights.

    Creates the encoder and decoder with a latent size of 120, attaches the
    cost function, and loads weights from the ``weights`` directory located
    next to the file that defines this method.

    Args:
        **kwargs: Forwarded unchanged to the ``tf.keras.Model`` constructor
            (a ``name=...`` keyword, if given, travels through here).
    """
    # `name` is not a parameter of this constructor, so it must not be
    # referenced directly; forwarding **kwargs alone avoids a NameError while
    # still letting callers pass `name=` through to the base class.
    super(VAE, self).__init__(**kwargs)
    self.dim_x = (3, _CAP, 1)
    self.encoder = Encoder_Z(dim_z=120).build()
    self.decoder = Decoder_X(dim_z=120).build()
    self.cost_func = VAECost(self)

    # Get the path of the script that defines this method
    script_path = inspect.getfile(inspect.currentframe())

    # Get the directory containing the script
    script_dir = os.path.dirname(os.path.abspath(script_path))

    # Construct the path to the weights folder (trailing separator kept so
    # the value passed to load_weights is unchanged)
    weights_dir = os.path.join(script_dir, 'weights') + os.sep

    # Load pretrained weights
    self.load_weights(weights_dir)
|
| 135 |
|
| 136 |
@tf.function()
|
| 137 |
def train_step(self, data):
|
|
|
|
| 153 |
|
| 154 |
mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
|
| 155 |
sd = tf.math.log(1 + tf.math.exp(rho))
|
| 156 |
+
z_sample = mu + sd * tf.random.normal(shape=(120,))
|
| 157 |
return z_sample, mu, sd
|
| 158 |
|
| 159 |
def generate(self, z_sample=None):
    """Decode a latent representation of a song.

    Args:
        z_sample: Latent input passed to ``self.decoder``. When ``None``, a
            vector is drawn with ``tf.random.normal(shape=(120,))`` and
            expanded along axis 0 before decoding.

    Returns:
        Whatever ``self.decoder`` returns for the latent input.
    """
    # Identity test on purpose: `== None` dispatches to the argument's own
    # __eq__, which for array-like inputs can be element-wise or ambiguous.
    if z_sample is None:
        z_sample = tf.expand_dims(tf.random.normal(shape=(120,)), axis=0)
    return self.decoder(z_sample)
|