import tensorflow as tf

class LocationSensitiveAttention(tf.keras.layers.Layer):
    """Location-sensitive attention: the alignment energy also conditions on
    the previous attention weights through a 1-D convolution, which encourages
    consistent forward movement over the encoder states."""

    def __init__(self, attn_dim=128, attn_filters=32, attn_kernel=31, **kw):
        super().__init__(**kw)
        self.attn_dim = attn_dim
        self.attn_filters = attn_filters
        self.attn_kernel = attn_kernel

    def build(self, input_shape):
        d = self.attn_dim
        # Projections for the decoder query, the encoder memory, and the
        # location features extracted from the previous attention weights.
        self.W_query = tf.keras.layers.Dense(d, use_bias=False)
        self.W_memory = tf.keras.layers.Dense(d, use_bias=False)
        self.loc_conv = tf.keras.layers.Conv1D(
            self.attn_filters, self.attn_kernel, padding='same', use_bias=False)
        self.W_loc = tf.keras.layers.Dense(d, use_bias=False)
        self.v = tf.keras.layers.Dense(1, use_bias=False)
        self.b = self.add_weight(name='attn_bias', shape=[d],
                                 initializer='zeros', trainable=True)
        super().build(input_shape)
    def call(self, query, memory, prev_weights):
        # query: [B, dec_dim], memory: [B, T_enc, enc_dim], prev_weights: [B, T_enc]
        q = self.W_query(tf.expand_dims(query, 1))             # [B, 1, d]
        m = self.W_memory(memory)                              # [B, T_enc, d]
        loc = self.loc_conv(tf.expand_dims(prev_weights, -1))  # [B, T_enc, filters]
        loc = self.W_loc(loc)                                  # [B, T_enc, d]
        e = self.v(tf.nn.tanh(q + m + loc + self.b))           # [B, T_enc, 1]
        e = tf.squeeze(e, -1)
        w = tf.nn.softmax(e, axis=-1)
        ctx = tf.reduce_sum(tf.expand_dims(w, -1) * memory, 1)  # [B, enc_dim]
        return ctx, w
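
# A minimal shape check for the attention layer (a sketch, assuming eager
# execution and illustrative sizes: batch 2, 50 encoder steps, 512-dim memory):
#   attn = LocationSensitiveAttention(attn_dim=128)
#   ctx, w = attn(tf.zeros([2, 512]), tf.zeros([2, 50, 512]), tf.zeros([2, 50]))
#   # ctx.shape == (2, 512), w.shape == (2, 50); w sums to 1 over the 50 steps.
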
class Prenet(tf.keras.layers.Layer):
    """Two-layer bottleneck applied to each decoder input frame. Dropout runs
    with training=True even at inference, as in Tacotron 2, which keeps the
    autoregressive decoder from collapsing onto its own predictions."""

    def __init__(self, units=256, **kw):
        super().__init__(**kw)
        self.fc1 = tf.keras.layers.Dense(units, activation='relu')
        self.fc2 = tf.keras.layers.Dense(units, activation='relu')
        self.drop1 = tf.keras.layers.Dropout(0.5)
        self.drop2 = tf.keras.layers.Dropout(0.5)

    def call(self, x, **_):
        # Dropout is deliberately always on (see class docstring).
        x = self.drop1(self.fc1(x), training=True)
        x = self.drop2(self.fc2(x), training=True)
        return x
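
# Because prenet dropout never switches off, two calls on the same input
# differ (a sketch, assuming eager execution and a dummy 80-dim frame):
#   p = Prenet(256)
#   x = tf.ones([1, 80])
#   bool(tf.reduce_all(p(x) == p(x)))  # almost surely False
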
class ConvBN(tf.keras.layers.Layer):
    def __init__(self, filters, kernel=5, drop=0.5, **kw):
        super().__init__(**kw)
        self.conv = tf.keras.layers.Conv1D(filters, kernel, padding='same')
        self.bn = tf.keras.layers.BatchNormalization()
        self.dropout = tf.keras.layers.Dropout(drop)

    def call(self, x, training=False):
        h = tf.nn.relu(self.bn(self.conv(x), training=training))
        return self.dropout(h, training=training)
class Encoder(tf.keras.layers.Layer):
    """Character embedding -> stack of ConvBN blocks -> bidirectional LSTM."""

    def __init__(self, vocab_size, emb_dim=512, enc_dim=512,
                 n_conv=3, conv_k=5, drop=0.5, **kw):
        super().__init__(**kw)
        self.emb = tf.keras.layers.Embedding(vocab_size, emb_dim)
        self.convs = [ConvBN(emb_dim, conv_k, drop) for _ in range(n_conv)]
        # Forward and backward halves concatenate back to enc_dim.
        self.bilstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(enc_dim // 2, return_sequences=True),
            merge_mode='concat')

    def call(self, x, training=False):
        x = self.emb(x)
        for c in self.convs:
            x = c(x, training=training)
        return self.bilstm(x)
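
# Encoder output shapes (a sketch, assuming a batch of 2 and 50 token ids):
#   enc = Encoder(vocab_size=60, emb_dim=512, enc_dim=512)
#   h = enc(tf.zeros([2, 50], dtype=tf.int32))
#   # h.shape == (2, 50, 512): one enc_dim vector per input token.
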
class PostNet(tf.keras.layers.Layer):
    """Stack of Conv1D/BatchNorm blocks predicting a residual refinement of
    the decoder's mel output. All layers use tanh except the final, linear
    one, which projects back to n_mels."""

    def __init__(self, n_mels=80, dim=512, n_layers=5, k=5, drop=0.5, **kw):
        super().__init__(**kw)
        self.layers_list = []
        for i in range(n_layers):
            is_last = i == n_layers - 1
            out = n_mels if is_last else dim
            self.layers_list.append((
                tf.keras.layers.Conv1D(out, k, padding='same'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Dropout(drop),
                is_last))

    def call(self, x, training=False):
        h = x
        for conv, bn, drop, last in self.layers_list:
            h = bn(conv(h), training=training)
            if not last:
                h = tf.nn.tanh(h)
            h = drop(h, training=training)
        # Residual connection: the postnet only refines the coarse mel input.
        return x + h
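
# Postnet shape check (a sketch, assuming a dummy batch of 2 with 100 frames):
#   post = PostNet(n_mels=80)
#   refined = post(tf.zeros([2, 100, 80]))
#   # refined.shape == (2, 100, 80): the residual add preserves the mel shape.
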
class Decoder(tf.keras.layers.Layer):
    def __init__(self, n_mels=80, dec_dim=1024, attn_dim=128,
                 prenet_dim=256, max_steps=1000, **kw):
        super().__init__(**kw)
        self.n_mels = n_mels
        self.dec_dim = dec_dim
        self.max_steps = max_steps
        self.prenet = Prenet(prenet_dim)
        self.attention = LocationSensitiveAttention(attn_dim)
        self.lstm1 = tf.keras.layers.LSTMCell(dec_dim)
        self.lstm2 = tf.keras.layers.LSTMCell(dec_dim)
        self.mel_proj = tf.keras.layers.Dense(n_mels)
        self.stop_proj = tf.keras.layers.Dense(1)
        self.attn_proj = tf.keras.layers.Dense(dec_dim)
    def _init_state(self, enc):
        B = tf.shape(enc)[0]
        T = tf.shape(enc)[1]
        enc_dim = tf.shape(enc)[2]
        return {
            'attn_w': tf.zeros([B, T]),
            's1': [tf.zeros([B, self.dec_dim]), tf.zeros([B, self.dec_dim])],
            's2': [tf.zeros([B, self.dec_dim]), tf.zeros([B, self.dec_dim])],
            'ctx': tf.zeros([B, enc_dim]),
            'mel': tf.zeros([B, self.n_mels]),
        }
    def _step(self, inp, enc, st):
        # One decoder step: prenet -> two LSTM cells -> attention -> projections.
        p = self.prenet(inp)
        x = tf.concat([p, st['ctx']], -1)
        h1, s1 = self.lstm1(x, st['s1'])
        h2, s2 = self.lstm2(h1, st['s2'])
        ctx, w = self.attention(h2, enc, st['attn_w'])
        out = self.attn_proj(tf.concat([h2, ctx], -1))
        mel = self.mel_proj(out)    # next mel frame
        stop = self.stop_proj(out)  # stop-token logit
        return mel, stop, w, {'attn_w': w, 's1': s1, 's2': s2, 'ctx': ctx, 'mel': mel}
    def call(self, enc, mel_tgt=None, training=False):
        st = self._init_state(enc)
        mels, stops, attns = [], [], []

        # Teacher forcing only when a target is provided during training.
        use_teacher = (mel_tgt is not None) and training

        if use_teacher:
            # Feed the ground-truth previous frame (a zero go-frame at t=0).
            # The Python loop assumes a statically known target length, i.e.
            # eager execution or a fixed time dimension.
            num_frames = mel_tgt.shape[1]
            for t in range(num_frames):
                inp = st['mel'] if t == 0 else mel_tgt[:, t - 1, :]
                m, s, w, st = self._step(inp, enc, st)
                mels.append(m)
                stops.append(s)
                attns.append(w)
        else:
            # Free-running inference: feed back the previous prediction and
            # stop once the mean stop-token probability crosses 0.5. The
            # tensor-valued `if` likewise assumes eager execution.
            for t in range(self.max_steps):
                m, s, w, st = self._step(st['mel'], enc, st)
                mels.append(m)
                stops.append(s)
                attns.append(w)
                if t > 10 and tf.reduce_mean(tf.sigmoid(s)) > 0.5:
                    break

        return tf.stack(mels, 1), tf.stack(stops, 1), tf.stack(attns, 1)
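
# Decoder output shapes (a sketch, assuming eager mode, a [2, 50, 512] encoder
# output, and an inference run that halts after T_dec frames):
#   dec = Decoder(n_mels=80, dec_dim=1024)
#   mel, stop, attn = dec(tf.zeros([2, 50, 512]), training=False)
#   # mel: [2, T_dec, 80], stop: [2, T_dec, 1], attn: [2, T_dec, 50]
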
class TeraVO(tf.keras.Model):
    """Tacotron-2-style text-to-mel model with a learned per-speaker voice
    embedding that is broadcast-added to every encoder timestep."""

    def __init__(self, vocab_size, n_mels=80, emb_dim=256, enc_dim=256,
                 dec_dim=512, attn_dim=128, prenet_dim=128, postnet_dim=256,
                 n_conv=3, num_voices=3, voice_emb_dim=64, max_steps=800, **kw):
        super().__init__(**kw)
        self.voice_emb = tf.keras.layers.Embedding(num_voices, voice_emb_dim)
        self.voice_proj = tf.keras.layers.Dense(enc_dim)
        self.encoder = Encoder(vocab_size, emb_dim, enc_dim, n_conv)
        self.decoder = Decoder(n_mels, dec_dim, attn_dim, prenet_dim, max_steps)
        self.postnet = PostNet(n_mels, postnet_dim)

    def call(self, inputs, training=False):
        text = inputs['text']
        vid = inputs['voice_id']
        mel_tgt = inputs.get('mel_target', None)
        enc = self.encoder(text, training=training)
        # Condition every encoder timestep on the requested voice.
        ve = self.voice_proj(self.voice_emb(vid))
        enc = enc + tf.expand_dims(ve, 1)
        mel, stop, attn = self.decoder(enc, mel_tgt, training=training)
        mel_post = self.postnet(mel, training=training)
        return {'mel_outputs': mel, 'mel_outputs_postnet': mel_post,
                'stop_tokens': stop, 'attention_weights': attn}
def create_model(vocab_size, n_mels=80, num_voices=3):
    return TeraVO(vocab_size, n_mels=n_mels, num_voices=num_voices)
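
# Minimal end-to-end smoke test (a sketch; the vocab size, sequence lengths,
# and frame count below are arbitrary dummy values, not project settings):
if __name__ == '__main__':
    model = create_model(vocab_size=60, n_mels=80, num_voices=3)
    batch = {
        'text': tf.random.uniform([2, 50], maxval=60, dtype=tf.int32),
        'voice_id': tf.constant([0, 2]),
        'mel_target': tf.random.normal([2, 120, 80]),
    }
    out = model(batch, training=True)  # teacher-forced pass
    for k, v in out.items():
        print(k, v.shape)
    # Expected: mel_outputs / mel_outputs_postnet -> (2, 120, 80),
    # stop_tokens -> (2, 120, 1), attention_weights -> (2, 120, 50).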