import tensorflow as tf
import numpy as np


class LocationSensitiveAttention(tf.keras.layers.Layer):
    """Additive attention whose energies also see the previous alignment.

    Energies are ``v^T tanh(W_query q + W_memory m + W_loc conv(prev_w) + b)``
    and are normalised with a softmax over the memory time axis.

    Args:
        attn_dim: Size of the shared projection space for the query, memory
            and location features.
        attn_filters: Number of filters of the 1-D convolution applied to the
            previous attention weights.
        attn_kernel: Kernel width of that convolution (``padding='same'``
            keeps the time length unchanged).
    """

    def __init__(self, attn_dim=128, attn_filters=32, attn_kernel=31, **kw):
        super().__init__(**kw)
        self.attn_dim = attn_dim
        self.attn_filters = attn_filters
        self.attn_kernel = attn_kernel

    def build(self, input_shape):
        # Sub-layers are created lazily here. Their attribute names double as
        # checkpoint keys, so keep them stable.
        d = self.attn_dim
        self.W_query = tf.keras.layers.Dense(d, use_bias=False)
        self.W_memory = tf.keras.layers.Dense(d, use_bias=False)
        self.loc_conv = tf.keras.layers.Conv1D(
            self.attn_filters, self.attn_kernel, padding='same', use_bias=False)
        self.W_loc = tf.keras.layers.Dense(d, use_bias=False)
        self.v = tf.keras.layers.Dense(1, use_bias=False)
        # Single shared bias added inside the tanh (the Dense layers are bias-free).
        self.b = self.add_weight(name='attn_bias', shape=[d],
                                 initializer='zeros', trainable=True)
        super().build(input_shape)

    def call(self, query, memory, prev_weights):
        """Compute one attention step.

        Args:
            query: Decoder state for this step. It is expanded at axis 1 so it
                broadcasts over every memory frame. In this file it is the
                second decoder LSTM output, ``(batch, dec_dim)``.
            memory: Encoder outputs, ``(batch, time, enc_dim)``.
            prev_weights: Attention weights from the previous step,
                ``(batch, time)``.

        Returns:
            ``(context, weights)``. ``context`` is ``(batch, enc_dim)``, the
            weighted sum of ``memory`` over time. ``weights`` is
            ``(batch, time)`` and sums to 1 over time.
        """
        # NOTE(review): no mask is applied, so padded memory frames (if the
        # batch contains padded sequences) also receive attention weight.
        q = self.W_query(tf.expand_dims(query, 1))
        m = self.W_memory(memory)
        # Location features: convolve the previous alignment along time.
        loc = self.W_loc(self.loc_conv(tf.expand_dims(prev_weights, -1)))
        e = tf.squeeze(self.v(tf.nn.tanh(q + m + loc + self.b)), -1)
        w = tf.nn.softmax(e, axis=-1)
        ctx = tf.reduce_sum(tf.expand_dims(w, -1) * memory, 1)
        return ctx, w


class Prenet(tf.keras.layers.Layer):
    """Two ReLU dense layers, each followed by dropout with rate 0.5.

    Dropout is called with ``training=True`` unconditionally, so the output
    is stochastic at inference as well. Any ``training`` kwarg passed in is
    ignored. NOTE(review): this looks deliberate (Tacotron 2-style prenet
    that keeps dropout on at inference); confirm before changing.
    """

    def __init__(self, units=256, **kw):
        super().__init__(**kw)
        self.fc1 = tf.keras.layers.Dense(units, activation='relu')
        self.fc2 = tf.keras.layers.Dense(units, activation='relu')
        self.drop1 = tf.keras.layers.Dropout(0.5)
        self.drop2 = tf.keras.layers.Dropout(0.5)

    def call(self, x, **_):
        x = self.drop1(self.fc1(x), training=True)
        x = self.drop2(self.fc2(x), training=True)
        return x


class ConvBN(tf.keras.layers.Layer):
    """Conv1D ('same' padding) -> BatchNorm -> ReLU -> Dropout.

    Args:
        filters: Output channels of the convolution.
        kernel: Convolution kernel width.
        drop: Dropout rate, active only when ``training`` is true.
    """

    def __init__(self, filters, kernel=5, drop=0.5, **kw):
        super().__init__(**kw)
        self.conv = tf.keras.layers.Conv1D(filters, kernel, padding='same')
        self.bn = tf.keras.layers.BatchNormalization()
        self.dropout = tf.keras.layers.Dropout(drop)

    def call(self, x, training=False):
        h = self.bn(self.conv(x), training=training)
        return self.dropout(tf.nn.relu(h), training=training)


class Encoder(tf.keras.layers.Layer):
    """Token embedding -> stack of ConvBN blocks -> bidirectional LSTM.

    Output channels are ``2 * (enc_dim // 2)``, i.e. ``enc_dim`` when
    ``enc_dim`` is even, because the two directions are concatenated.
    """

    def __init__(self, vocab_size, emb_dim=512, enc_dim=512, n_conv=3,
                 conv_k=5, drop=0.5, **kw):
        super().__init__(**kw)
        self.emb = tf.keras.layers.Embedding(vocab_size, emb_dim)
        self.convs = [ConvBN(emb_dim, conv_k, drop) for _ in range(n_conv)]
        self.bilstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(enc_dim // 2, return_sequences=True),
            merge_mode='concat')

    def call(self, x, training=False):
        """Encode token ids (presumably ``(batch, time)`` ints; confirm).

        Returns a per-token sequence, ``return_sequences=True``.
        NOTE(review): no padding mask is produced or propagated.
        """
        x = self.emb(x)
        for c in self.convs:
            x = c(x, training=training)
        return self.bilstm(x)


class PostNet(tf.keras.layers.Layer):
    """Residual conv stack that refines the decoder's mel prediction.

    ``n_layers`` Conv1D+BatchNorm+Dropout stages. Every stage except the last
    applies tanh. The last stage projects to ``n_mels`` channels so the
    result can be added back to the input (``x + h``).
    """

    def __init__(self, n_mels=80, dim=512, n_layers=5, k=5, drop=0.5, **kw):
        super().__init__(**kw)
        # Each entry is (conv, batchnorm, dropout, is_last_stage). The
        # attribute name and tuple layout are kept for checkpoint
        # compatibility.
        self.layers_list = []
        for i in range(n_layers):
            out = n_mels if i == n_layers - 1 else dim
            self.layers_list.append((
                tf.keras.layers.Conv1D(out, k, padding='same'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Dropout(drop),
                i == n_layers - 1))

    def call(self, x, training=False):
        h = x
        for conv, bn, drop, last in self.layers_list:
            h = bn(conv(h), training=training)
            if not last:
                h = tf.nn.tanh(h)
            h = drop(h, training=training)
        return x + h


class Decoder(tf.keras.layers.Layer):
    """Autoregressive attention decoder that emits one mel frame per step.

    Args:
        n_mels: Mel channels per output frame.
        dec_dim: Units of both decoder LSTM cells and of ``attn_proj``.
        attn_dim: Projection size of the attention module.
        prenet_dim: Units of the prenet layers.
        max_steps: Upper bound on free-running (inference) decoding steps.
    """

    def __init__(self, n_mels=80, dec_dim=1024, attn_dim=128, prenet_dim=256,
                 max_steps=1000, **kw):
        super().__init__(**kw)
        self.n_mels = n_mels
        self.dec_dim = dec_dim
        self.max_steps = max_steps
        self.prenet = Prenet(prenet_dim)
        self.attention = LocationSensitiveAttention(attn_dim)
        self.lstm1 = tf.keras.layers.LSTMCell(dec_dim)
        self.lstm2 = tf.keras.layers.LSTMCell(dec_dim)
        self.mel_proj = tf.keras.layers.Dense(n_mels)
        self.stop_proj = tf.keras.layers.Dense(1)
        self.attn_proj = tf.keras.layers.Dense(dec_dim)

    def _init_state(self, enc):
        """Return an all-zeros decoder state sized from ``enc`` (batch, time, dim)."""
        B = tf.shape(enc)[0]
        T = tf.shape(enc)[1]
        enc_dim = tf.shape(enc)[2]
        return {
            'attn_w': tf.zeros([B, T]),
            's1': [tf.zeros([B, self.dec_dim]), tf.zeros([B, self.dec_dim])],
            's2': [tf.zeros([B, self.dec_dim]), tf.zeros([B, self.dec_dim])],
            'ctx': tf.zeros([B, enc_dim]),
            # Zero "go" frame, fed as input to the first step.
            'mel': tf.zeros([B, self.n_mels]),
        }

    def _step(self, inp, enc, st):
        """Run one decoder step.

        Pipeline: prenet(previous frame) + previous context -> LSTM1 -> LSTM2.
        The LSTM2 output queries the attention. ``attn_proj([h2, ctx])`` then
        feeds the mel and stop projections.

        Returns:
            ``(mel, stop_logit, attn_weights, new_state)``. ``stop_logit`` is
            an unactivated ``(batch, 1)`` score.
        """
        p = self.prenet(inp)
        x = tf.concat([p, st['ctx']], -1)
        h1, s1 = self.lstm1(x, st['s1'])
        h2, s2 = self.lstm2(h1, st['s2'])
        ctx, w = self.attention(h2, enc, st['attn_w'])
        out = self.attn_proj(tf.concat([h2, ctx], -1))
        mel = self.mel_proj(out)
        stop = self.stop_proj(out)
        return mel, stop, w, {'attn_w': w, 's1': s1, 's2': s2, 'ctx': ctx,
                              'mel': mel}

    def call(self, enc, mel_tgt=None, training=False):
        """Decode encoder outputs into a mel sequence.

        Args:
            enc: Encoder outputs, ``(batch, time, enc_dim)``.
            mel_tgt: Optional target frames ``(batch, frames, n_mels)``, used
                for teacher forcing.
            training: Teacher forcing is used only when this is true AND
                ``mel_tgt`` is given. With ``training=False`` the decoder
                free-runs even if a target is supplied, so the output length
                need not match the target.

        Returns:
            ``(mels, stop_logits, attentions)`` with shapes
            ``(batch, steps, n_mels)``, ``(batch, steps, 1)`` and
            ``(batch, steps, enc_time)``.

        Raises:
            ValueError: if teacher forcing is requested but the number of
                target frames is not statically known.
        """
        st = self._init_state(enc)
        mels, stops, attns = [], [], []
        use_teacher = (mel_tgt is not None) and training
        if use_teacher:
            num_frames = mel_tgt.shape[1]
            if num_frames is None:
                # Only reachable in graph mode (eager tensors have static
                # shapes). The Python loop below cannot iterate a symbolic
                # length.
                raise ValueError(
                    "Decoder teacher forcing needs a static number of target "
                    "frames, but mel_tgt.shape[1] is None.")
            for t in range(num_frames):
                # Step 0 consumes the zero "go" frame; step t consumes the
                # ground-truth frame t - 1.
                inp = st['mel'] if t == 0 else mel_tgt[:, t - 1, :]
                m, s, w, st = self._step(inp, enc, st)
                mels.append(m)
                stops.append(s)
                attns.append(w)
        else:
            # Free-running: each step consumes the previously predicted frame.
            # NOTE(review): the Python `break` on a tensor condition likely
            # only works eagerly; confirm before wrapping in tf.function.
            for t in range(self.max_steps):
                m, s, w, st = self._step(st['mel'], enc, st)
                mels.append(m)
                stops.append(s)
                attns.append(w)
                # Stop once the batch-mean stop probability exceeds 0.5, but
                # never before step 11.
                if t > 10 and tf.reduce_mean(tf.sigmoid(s)) > 0.5:
                    break
        return tf.stack(mels, 1), tf.stack(stops, 1), tf.stack(attns, 1)


class TeraVO(tf.keras.Model):
    """Text-to-mel model conditioned on a learned per-voice embedding.

    The voice embedding is projected to the encoder width and added to every
    encoder frame before decoding. A residual PostNet refines the decoder
    output.
    """

    def __init__(self, vocab_size, n_mels=80, emb_dim=256, enc_dim=256,
                 dec_dim=512, attn_dim=128, prenet_dim=128, postnet_dim=256,
                 n_conv=3, num_voices=3, voice_emb_dim=64, max_steps=800, **kw):
        super().__init__(**kw)
        self.voice_emb = tf.keras.layers.Embedding(num_voices, voice_emb_dim)
        self.voice_proj = tf.keras.layers.Dense(enc_dim)
        self.encoder = Encoder(vocab_size, emb_dim, enc_dim, n_conv)
        self.decoder = Decoder(n_mels, dec_dim, attn_dim, prenet_dim, max_steps)
        self.postnet = PostNet(n_mels, postnet_dim)

    def call(self, inputs, training=False):
        """Run the full model.

        Args:
            inputs: Dict with ``'text'`` (token ids), ``'voice_id'`` (one id
                per example; ``(batch,)``, ``(batch, 1)`` and scalar are all
                accepted) and optional ``'mel_target'`` (teacher-forcing
                frames).
            training: Forwarded to the sub-layers. It also gates teacher
                forcing in the decoder.

        Returns:
            Dict with ``'mel_outputs'``, ``'mel_outputs_postnet'``,
            ``'stop_tokens'`` (logits) and ``'attention_weights'``.
        """
        text = inputs['text']
        # Flatten ids to rank 1. A (batch, 1) id tensor would otherwise give a
        # (batch, 1, 1, enc_dim) offset that silently broadcasts the encoder
        # output to rank 4. A no-op for (batch,) ids.
        vid = tf.reshape(inputs['voice_id'], [-1])
        mel_tgt = inputs.get('mel_target', None)
        enc = self.encoder(text, training=training)
        ve = self.voice_proj(self.voice_emb(vid))
        enc = enc + tf.expand_dims(ve, 1)
        mel, stop, attn = self.decoder(enc, mel_tgt, training=training)
        mel_post = self.postnet(mel, training=training)
        return {'mel_outputs': mel,
                'mel_outputs_postnet': mel_post,
                'stop_tokens': stop,
                'attention_weights': attn}


def create_model(vocab_size, n_mels=80, num_voices=3):
    """Build a TeraVO model with default hyperparameters for everything else."""
    return TeraVO(vocab_size, n_mels=n_mels, num_voices=num_voices)