# text2speech/train.py — transformer text-to-speech training script.
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import *
import os
from datetime import datetime
import json
class PositionalEncoding(Layer):
    """Adds fixed sinusoidal position information to token embeddings."""

    def __init__(self, position, d_model):
        super(PositionalEncoding, self).__init__()
        # Precompute the full table once; call() slices it per sequence length.
        self.pos_encoding = self.positional_encoding(position, d_model)

    def get_angles(self, position, i, d_model):
        """Angle values for every (position, dimension) pair."""
        rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
        return position * rates

    def positional_encoding(self, position, d_model):
        """Build a (1, position, d_model) float32 encoding table.

        Even-indexed channels get sines, odd-indexed channels get cosines;
        the two halves are concatenated (not interleaved), matching the
        layout this model was defined with.
        """
        angles = self.get_angles(
            position=np.arange(position)[:, np.newaxis],
            i=np.arange(d_model)[np.newaxis, :],
            d_model=d_model)
        table = np.concatenate(
            [np.sin(angles[:, 0::2]), np.cos(angles[:, 1::2])], axis=-1)
        return tf.cast(table[np.newaxis, ...], dtype=tf.float32)

    def call(self, inputs):
        # Add only as many positions as the input actually has.
        seq_len = tf.shape(inputs)[1]
        return inputs + self.pos_encoding[:, :seq_len, :]
class MultiHeadAttention(Layer):
    """Scaled dot-product attention with d_model split across num_heads."""

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads  # per-head feature width
        # Independent projections for queries, keys, values, plus the output mix.
        self.wq = Dense(d_model)
        self.wk = Dense(d_model)
        self.wv = Dense(d_model)
        self.dense = Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq, d_model) -> (batch, heads, seq, depth)."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        batch_size = tf.shape(q)[0]
        # Project, then fan each tensor out into heads.
        q = self.split_heads(self.wq(q), batch_size)
        k = self.split_heads(self.wk(k), batch_size)
        v = self.split_heads(self.wv(v), batch_size)

        # Scaled dot-product attention scores.
        logits = tf.matmul(q, k, transpose_b=True)
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        logits = logits / tf.math.sqrt(dk)
        if mask is not None:
            # Large negative bias drives masked positions to ~0 after softmax.
            logits += (mask * -1e9)
        weights = tf.nn.softmax(logits, axis=-1)
        context = tf.matmul(weights, v)

        # Fold the heads back together: (batch, seq, d_model).
        context = tf.transpose(context, perm=[0, 2, 1, 3])
        context = tf.reshape(context, (batch_size, -1, self.d_model))
        return self.dense(context)
class TransformerBlock(Layer):
    """Post-norm encoder layer: self-attention + feed-forward, each with
    dropout, a residual connection, and layer normalization."""

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(TransformerBlock, self).__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        # Position-wise feed-forward: expand to dff, project back to d_model.
        self.ffn = tf.keras.Sequential([
            Dense(dff, activation='relu'),
            Dense(d_model),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, x, training=False, mask=None):
        # Self-attention sub-layer.
        attn = self.dropout1(self.mha(x, x, x, mask), training=training)
        x = self.layernorm1(x + attn)
        # Feed-forward sub-layer.
        ff = self.dropout2(self.ffn(x), training=training)
        return self.layernorm2(x + ff)
class TextToSpeechTransformer(tf.keras.Model):
    """Encoder-only transformer mapping token ids to 80-channel mel frames."""

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 maximum_position_encoding, rate=0.1):
        super(TextToSpeechTransformer, self).__init__()
        self.embedding = Embedding(input_vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(maximum_position_encoding, d_model)
        self.dropout = Dropout(rate)
        self.transformer_blocks = [TransformerBlock(d_model, num_heads, dff, rate)
                                   for _ in range(num_layers)]
        self.final_layer = Dense(80)  # one mel-spectrogram frame per position

    def call(self, x, training=False, mask=None):
        # Embed, add positions, regularize, then run the encoder stack.
        x = self.dropout(self.pos_encoding(self.embedding(x)), training=training)
        for block in self.transformer_blocks:
            x = block(x, training=training, mask=mask)
        return self.final_layer(x)
class TTSTrainer:
    """Owns the model lifecycle: build, train with callbacks, save, reload.

    model_params    : kwargs forwarded verbatim to TextToSpeechTransformer.
    training_params : dict with 'learning_rate', 'batch_size', 'epochs'.
    """

    def __init__(self, model_params, training_params):
        self.model_params = model_params
        self.training_params = training_params
        self.model = self._build_model()
        # Timestamped run directory so repeated runs never overwrite each other.
        self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.checkpoint_dir = f"checkpoints/{self.timestamp}"
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def _build_model(self):
        """Construct and compile a fresh model from the current params."""
        model = TextToSpeechTransformer(**self.model_params)
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=self.training_params['learning_rate']
        )
        model.compile(
            optimizer=optimizer,
            loss=tf.keras.losses.Huber(delta=1.0),  # robust to outlier frames
            metrics=['mae']
        )
        return model

    def _create_dataset(self, texts, mels, batch_size):
        """Wrap (texts, mels) arrays in a cached/shuffled/batched pipeline."""
        dataset = tf.data.Dataset.from_tensor_slices((texts, mels))
        dataset = dataset.cache()
        dataset = dataset.shuffle(10000)
        dataset = dataset.batch(batch_size)
        return dataset.prefetch(tf.data.AUTOTUNE)

    def train(self, texts, mels):
        """Train on a 90/10 train/validation split; return the Keras History."""
        train_size = int(0.9 * len(texts))
        train_texts, val_texts = texts[:train_size], texts[train_size:]
        train_mels, val_mels = mels[:train_size], mels[train_size:]
        batch_size = self.training_params['batch_size']
        train_dataset = self._create_dataset(train_texts, train_mels, batch_size)
        val_dataset = self._create_dataset(val_texts, val_mels, batch_size)
        # Checkpoint *prefix* path. The original also ran os.makedirs on this
        # exact path, creating a useless empty directory next to the weight
        # files TF writes as "<prefix>.index"/"<prefix>.data-*"; the parent
        # run directory already exists from __init__.
        checkpoint_path = f"{self.checkpoint_dir}/model"
        callbacks = [
            tf.keras.callbacks.ModelCheckpoint(
                filepath=checkpoint_path,
                save_weights_only=True,
                save_best_only=True,
                monitor='val_loss'
            ),
            tf.keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=5,
                restore_best_weights=True
            ),
            tf.keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=2
            ),
            tf.keras.callbacks.TensorBoard(
                log_dir=f"{self.checkpoint_dir}/logs"
            )
        ]
        history = self.model.fit(
            train_dataset,
            validation_data=val_dataset,
            epochs=self.training_params['epochs'],
            callbacks=callbacks
        )
        self._save_model_and_config()
        return history

    def _save_model_and_config(self):
        """Persist hyper-parameters, weights, and a SavedModel export."""
        config = {
            'model_params': self.model_params,
            'training_params': self.training_params
        }
        config_path = f"{self.checkpoint_dir}/config.json"
        with open(config_path, 'w') as f:
            json.dump(config, f)
        weights_path = f"{self.checkpoint_dir}/model_weights"
        self.model.save_weights(weights_path)
        tf.saved_model.save(self.model, f"{self.checkpoint_dir}/saved_model")

    def load_model(self, checkpoint_dir):
        """Restore a trained model from a run directory.

        BUG FIX: the original read config.json into `config` and then ignored
        it, rebuilding the model from whatever params this trainer happened to
        hold — so loading weights failed (or silently mismatched) whenever the
        current params differed from the saved run's. The saved params are now
        applied before rebuilding, so architecture and weights always agree.
        """
        config_path = f"{checkpoint_dir}/config.json"
        with open(config_path, 'r') as f:
            config = json.load(f)
        self.model_params = config['model_params']
        self.training_params = config['training_params']
        self.model = self._build_model()
        weights_path = f"{checkpoint_dir}/model_weights"
        self.model.load_weights(weights_path)
if __name__ == "__main__":
model_params = {
'num_layers': 6,
'd_model': 256,
'num_heads': 8,
'dff': 1024,
'input_vocab_size': 1000,
'maximum_position_encoding': 2048,
'rate': 0.1
}
training_params = {
'batch_size': 32,
'epochs': 100,
'learning_rate': 0.001
}
trainer = TTSTrainer(model_params, training_params)
# Generate some dummy data for testing
input_texts = np.random.randint(0, 1000, size=(1000, 100))
target_mels = np.random.uniform(size=(1000, 100, 80))
history = trainer.train(input_texts, target_mels)