# Tera.VO / model.py
# Uploaded by vedaco with huggingface_hub (commit 0f0d3f1, verified).
import tensorflow as tf
import numpy as np
class LocationSensitiveAttention(tf.keras.layers.Layer):
    """Additive attention whose scores also see the previous alignment.

    Each memory position is scored from three terms projected into a shared
    ``attn_dim``-wide space:
    - the decoder query;
    - the memory itself;
    - a 1-D convolution over the previous attention weights.

    The scores are softmax-normalised over the memory axis.

    Args:
        attn_dim: Width of the shared projection space.
        attn_filters: Number of filters of the location convolution.
        attn_kernel: Kernel width of the location convolution.
    """

    def __init__(self, attn_dim=128, attn_filters=32, attn_kernel=31, **kw):
        super().__init__(**kw)
        self.attn_dim = attn_dim
        self.attn_filters = attn_filters
        self.attn_kernel = attn_kernel

    def build(self, input_shape):
        # Sub-layers are created lazily here. Keep their creation order as is:
        # Keras' automatic layer naming (and thus saved weight names) follows it.
        width = self.attn_dim
        dense = tf.keras.layers.Dense
        self.W_query = dense(width, use_bias=False)
        self.W_memory = dense(width, use_bias=False)
        self.loc_conv = tf.keras.layers.Conv1D(
            self.attn_filters,
            self.attn_kernel,
            padding='same',
            use_bias=False,
        )
        self.W_loc = dense(width, use_bias=False)
        self.v = dense(1, use_bias=False)
        # One bias vector shared by the summed energy terms.
        self.b = self.add_weight(
            name='attn_bias',
            shape=[width],
            initializer='zeros',
            trainable=True,
        )
        super().build(input_shape)

    def call(self, query, memory, prev_weights):
        """Compute one attention step.

        Args:
            query: Decoder state. It gets a new axis 1 so it broadcasts over
                every memory position (presumably (batch, query_dim); confirm).
            memory: Sequence being attended over. Weights are normalised and
                summed along axis 1.
            prev_weights: Previous attention weights. A trailing channel axis
                is added before the Conv1D, so this is expected to be
                (batch, memory_len).

        Returns:
            ``(context, weights)``. ``context`` is the weight-averaged
            ``memory`` (reduced over axis 1); ``weights`` are the new
            softmax-normalised attention weights.
        """
        query_term = self.W_query(tf.expand_dims(query, axis=1))
        memory_term = self.W_memory(memory)
        location_term = self.W_loc(self.loc_conv(prev_weights[..., tf.newaxis]))
        energy_input = query_term + memory_term + location_term + self.b
        energies = tf.squeeze(self.v(tf.nn.tanh(energy_input)), axis=-1)
        weights = tf.nn.softmax(energies, axis=-1)
        context = tf.reduce_sum(weights[..., tf.newaxis] * memory, axis=1)
        return context, weights
class Prenet(tf.keras.layers.Layer):
    """Two ReLU dense layers, each followed by dropout (rate 0.5).

    Dropout is always applied with ``training=True``, so it is active at
    inference too. A ``training`` keyword passed by the caller is accepted and
    ignored.

    Args:
        units: Width of both dense layers.
    """

    def __init__(self, units=256, **kw):
        super().__init__(**kw)
        # Creation order kept as-is (it drives Keras auto-naming of sub-layers).
        self.fc1 = tf.keras.layers.Dense(units, activation='relu')
        self.fc2 = tf.keras.layers.Dense(units, activation='relu')
        self.drop1 = tf.keras.layers.Dropout(0.5)
        self.drop2 = tf.keras.layers.Dropout(0.5)

    def call(self, x, **_ignored):
        stages = ((self.fc1, self.drop1), (self.fc2, self.drop2))
        for dense, dropout in stages:
            # training=True on purpose: dropout stays on at inference.
            x = dropout(dense(x), training=True)
        return x
class ConvBN(tf.keras.layers.Layer):
    """Conv1D ('same' padding), then BatchNorm, ReLU and Dropout.

    Args:
        filters: Number of convolution filters.
        kernel: Convolution kernel width.
        drop: Dropout rate.
    """

    def __init__(self, filters, kernel=5, drop=0.5, **kw):
        super().__init__(**kw)
        self.conv = tf.keras.layers.Conv1D(filters, kernel, padding='same')
        self.bn = tf.keras.layers.BatchNormalization()
        self.dropout = tf.keras.layers.Dropout(drop)

    def call(self, x, training=False):
        features = self.conv(x)
        normalized = self.bn(features, training=training)
        activated = tf.nn.relu(normalized)
        return self.dropout(activated, training=training)
class Encoder(tf.keras.layers.Layer):
    """Token encoder: embedding, ``n_conv`` ConvBN blocks, then a BiLSTM.

    The bidirectional LSTM uses ``enc_dim // 2`` units per direction and
    concatenates both directions. Its output width is therefore ``enc_dim``
    when ``enc_dim`` is even.

    Args:
        vocab_size: Size of the token vocabulary.
        emb_dim: Embedding width; also the filter count of each ConvBN block.
        enc_dim: Total BiLSTM output width (split across both directions).
        n_conv: Number of ConvBN blocks.
        conv_k: Kernel width of each ConvBN block.
        drop: Dropout rate inside each ConvBN block.
    """

    def __init__(self, vocab_size, emb_dim=512, enc_dim=512,
                 n_conv=3, conv_k=5, drop=0.5, **kw):
        super().__init__(**kw)
        self.emb = tf.keras.layers.Embedding(vocab_size, emb_dim)
        self.convs = [ConvBN(emb_dim, conv_k, drop) for _ in range(n_conv)]
        self.bilstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(enc_dim // 2, return_sequences=True),
            merge_mode='concat',
        )

    def call(self, x, training=False):
        hidden = self.emb(x)
        for block in self.convs:
            hidden = block(hidden, training=training)
        return self.bilstm(hidden)
class PostNet(tf.keras.layers.Layer):
    """Residual convolutional refinement of a mel sequence.

    The stack has ``n_layers`` blocks of Conv1D, BatchNorm, tanh and Dropout.
    The last block has no tanh and outputs ``n_mels`` channels. The stack's
    output is added back to the input.

    Args:
        n_mels: Channel count of the input and of the final conv layer.
        dim: Channel count of the intermediate conv layers.
        n_layers: Number of conv blocks.
        k: Convolution kernel width.
        drop: Dropout rate applied after every block.
    """

    def __init__(self, n_mels=80, dim=512, n_layers=5, k=5, drop=0.5, **kw):
        super().__init__(**kw)
        # Each entry is (conv, batch_norm, dropout, is_last). The tuple layout
        # is kept unchanged because Keras tracks these layers by list/tuple
        # position.
        self.layers_list = []
        final_index = n_layers - 1
        for index in range(n_layers):
            is_last = index == final_index
            channels = n_mels if is_last else dim
            self.layers_list.append((
                tf.keras.layers.Conv1D(channels, k, padding='same'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Dropout(drop),
                is_last))

    def call(self, x, training=False):
        residual = x
        for conv, bn, drop, last in self.layers_list:
            residual = bn(conv(residual), training=training)
            if not last:
                # The final block stays linear.
                residual = tf.nn.tanh(residual)
            residual = drop(residual, training=training)
        return x + residual
class Decoder(tf.keras.layers.Layer):
    """Attention-based mel decoder built from two stacked LSTM cells.

    Each step does the following:
    1. Pass the previous mel frame through a prenet.
    2. Concatenate the previous attention context.
    3. Run two stacked LSTM cells.
    4. Attend over the encoder output using the second cell's output.
    5. Project ``[lstm2_output, context]`` to a mel frame and a stop logit.

    Args:
        n_mels: Mel channels per output frame.
        dec_dim: Units of each LSTM cell and of ``attn_proj``.
        attn_dim: Projection width inside the attention layer.
        prenet_dim: Width of the prenet dense layers.
        max_steps: Maximum number of steps in autoregressive mode.
    """

    def __init__(self, n_mels=80, dec_dim=1024, attn_dim=128,
                 prenet_dim=256, max_steps=1000, **kw):
        super().__init__(**kw)
        self.n_mels = n_mels
        self.dec_dim = dec_dim
        self.max_steps = max_steps
        # Sub-layer creation order kept unchanged (drives Keras auto-naming).
        self.prenet = Prenet(prenet_dim)
        self.attention = LocationSensitiveAttention(attn_dim)
        self.lstm1 = tf.keras.layers.LSTMCell(dec_dim)
        self.lstm2 = tf.keras.layers.LSTMCell(dec_dim)
        self.mel_proj = tf.keras.layers.Dense(n_mels)
        self.stop_proj = tf.keras.layers.Dense(1)
        self.attn_proj = tf.keras.layers.Dense(dec_dim)

    def _init_state(self, enc):
        """Return the all-zero initial decoder state for ``enc``.

        Batch size, encoder length and encoder width are read dynamically
        from axes 0, 1 and 2 of ``enc``.
        """
        B = tf.shape(enc)[0]
        T = tf.shape(enc)[1]
        enc_dim = tf.shape(enc)[2]
        return {
            'attn_w': tf.zeros([B, T]),
            's1': [tf.zeros([B, self.dec_dim]), tf.zeros([B, self.dec_dim])],
            's2': [tf.zeros([B, self.dec_dim]), tf.zeros([B, self.dec_dim])],
            'ctx': tf.zeros([B, enc_dim]),
            # All-zero "go" frame fed at the first step.
            'mel': tf.zeros([B, self.n_mels]),
        }

    def _step(self, inp, enc, st):
        """Run one decoder step on input frame ``inp``.

        Returns:
            ``(mel, stop_logit, attn_weights, new_state)``. ``new_state`` has
            the same keys as ``_init_state``. Its ``'mel'`` entry is the frame
            just predicted, which autoregressive decoding feeds back in.
        """
        p = self.prenet(inp)
        x = tf.concat([p, st['ctx']], -1)
        h1, s1 = self.lstm1(x, st['s1'])
        h2, s2 = self.lstm2(h1, st['s2'])
        ctx, w = self.attention(h2, enc, st['attn_w'])
        out = self.attn_proj(tf.concat([h2, ctx], -1))
        mel = self.mel_proj(out)
        stop = self.stop_proj(out)
        return mel, stop, w, {'attn_w': w, 's1': s1, 's2': s2,
                              'ctx': ctx, 'mel': mel}

    def call(self, enc, mel_tgt=None, training=False):
        """Decode mel frames from encoder outputs.

        Teacher forcing is used only when ``mel_tgt`` is given and
        ``training`` is truthy. In that mode, step ``t`` is fed target frame
        ``t - 1`` (the zero "go" frame at ``t == 0``), and exactly
        ``mel_tgt.shape[1]`` frames are produced.

        Otherwise the decoder feeds back its own predictions for up to
        ``max_steps`` steps. After step 10 it stops early once every batch
        item's stop probability exceeds 0.5.

        Returns:
            ``(mels, stop_logits, attention_weights)``, each stacked along
            axis 1.

        Raises:
            ValueError: If teacher forcing is requested but ``mel_tgt.shape[1]``
                is not statically known. The decode loop is a Python loop and
                needs a concrete frame count.
        """
        st = self._init_state(enc)
        mels, stops, attns = [], [], []
        use_teacher = mel_tgt is not None and training
        if use_teacher:
            num_frames = mel_tgt.shape[1]
            if num_frames is None:
                raise ValueError(
                    'Teacher forcing requires a statically known number of '
                    'target frames, but mel_tgt.shape[1] is None.')
            for t in range(num_frames):
                inp = st['mel'] if t == 0 else mel_tgt[:, t - 1, :]
                m, s, w, st = self._step(inp, enc, st)
                mels.append(m)
                stops.append(s)
                attns.append(w)
        else:
            # Autoregressive inference: feed back the previous prediction.
            for t in range(self.max_steps):
                m, s, w, st = self._step(st['mel'], enc, st)
                mels.append(m)
                stops.append(s)
                attns.append(w)
                # Stop only when *every* batch item predicts end-of-utterance.
                # A batch-mean test would truncate utterances that were still
                # unfinished as soon as the average crossed the threshold.
                # For batch size 1 the two tests are equivalent.
                if t > 10 and tf.reduce_all(tf.sigmoid(s) > 0.5):
                    break
        return tf.stack(mels, 1), tf.stack(stops, 1), tf.stack(attns, 1)
class TeraVO(tf.keras.Model):
    """Multi-voice text-to-mel model.

    The pipeline is Encoder, then a voice offset, then Decoder, then PostNet.
    A learned per-voice embedding is projected to the encoder width and added
    to every encoder time step before decoding.

    Args:
        vocab_size: Size of the text token vocabulary.
        n_mels: Mel channels per frame.
        emb_dim, enc_dim, n_conv: Encoder hyper-parameters.
        dec_dim, attn_dim, prenet_dim, max_steps: Decoder hyper-parameters.
        postnet_dim: Intermediate width of the PostNet.
        num_voices: Number of distinct voice ids.
        voice_emb_dim: Width of the voice embedding before projection.
    """

    def __init__(self, vocab_size, n_mels=80, emb_dim=256, enc_dim=256,
                 dec_dim=512, attn_dim=128, prenet_dim=128, postnet_dim=256,
                 n_conv=3, num_voices=3, voice_emb_dim=64, max_steps=800, **kw):
        super().__init__(**kw)
        self.voice_emb = tf.keras.layers.Embedding(num_voices, voice_emb_dim)
        self.voice_proj = tf.keras.layers.Dense(enc_dim)
        self.encoder = Encoder(vocab_size, emb_dim, enc_dim, n_conv)
        self.decoder = Decoder(n_mels, dec_dim, attn_dim, prenet_dim, max_steps)
        self.postnet = PostNet(n_mels, postnet_dim)

    def call(self, inputs, training=False):
        """Run the model on a mapping of inputs.

        Args:
            inputs: Mapping with ``'text'`` (token ids), ``'voice_id'``
                (one id per example, shaped (batch,) or (batch, 1)) and an
                optional ``'mel_target'``. The target is used for teacher
                forcing only when ``training`` is true.
            training: Forwarded to the encoder, decoder and postnet.

        Returns:
            Dict with keys ``'mel_outputs'``, ``'mel_outputs_postnet'``,
            ``'stop_tokens'`` and ``'attention_weights'``.
        """
        text = inputs['text']
        # Flatten to one id per example. Otherwise a (batch, 1) voice_id
        # yields a rank-4 offset that broadcasts against the rank-3 encoder
        # output into a wrong (batch, batch, time, enc_dim) tensor.
        vid = tf.reshape(inputs['voice_id'], [-1])
        mel_tgt = inputs.get('mel_target', None)
        enc = self.encoder(text, training=training)
        ve = self.voice_proj(self.voice_emb(vid))
        # (batch, enc_dim) -> (batch, 1, enc_dim): same offset at every step.
        enc = enc + tf.expand_dims(ve, 1)
        mel, stop, attn = self.decoder(enc, mel_tgt, training=training)
        mel_post = self.postnet(mel, training=training)
        return {'mel_outputs': mel, 'mel_outputs_postnet': mel_post,
                'stop_tokens': stop, 'attention_weights': attn}
def create_model(vocab_size, n_mels=80, num_voices=3, **kwargs):
    """Construct a ``TeraVO`` model.

    Args:
        vocab_size: Size of the text token vocabulary.
        n_mels: Mel channels per frame.
        num_voices: Number of distinct voice ids.
        **kwargs: Any other ``TeraVO`` hyper-parameters (e.g. ``max_steps``,
            ``dec_dim``), forwarded unchanged. Omitted ones keep the
            ``TeraVO`` defaults.

    Returns:
        A new, unbuilt ``TeraVO`` instance.
    """
    return TeraVO(vocab_size, n_mels=n_mels, num_voices=num_voices, **kwargs)