# bn_multi_tribe_mt/src/pipes/models.py
# Author: MasumBhuiyan — "Updated trainer" (commit 6543d58)
import tensorflow as tf
class Seq2Seq:
    """Encoder-decoder (LSTM) sequence-to-sequence model built with the
    Keras functional API.

    The encoder embeds and encodes the source sequence; its final hidden
    and cell states initialize the decoder, which produces a softmax
    distribution over the output vocabulary at every timestep
    (teacher-forcing layout: the decoder receives the target sequence as
    its own input).
    """

    def __init__(self,
                 input_vocab_size,
                 output_vocab_size,
                 embedding_dim,
                 hidden_units):
        """Configure hyperparameters and instantiate the model's layers.

        Args:
            input_vocab_size: size of the source-side token vocabulary.
            output_vocab_size: size of the target-side token vocabulary.
            embedding_dim: dimensionality of both embedding layers.
            hidden_units: number of LSTM units in encoder and decoder.
        """
        # Training configuration — consumed when the model is compiled
        # (see build()) and presumably by the trainer for fit(); the
        # trainer is outside this file, so epochs/batch_size usage is
        # assumed — TODO confirm against the caller.
        self.epochs = 10
        self.batch_size = 64
        self.metrics = ['accuracy']
        self.loss = tf.keras.losses.SparseCategoricalCrossentropy()
        self.optimizer = tf.keras.optimizers.Adam()

        # Populated lazily by build(); get() triggers a build if needed.
        self.model = None

        self.embedding_dim = embedding_dim
        self.hidden_units = hidden_units
        self.input_vocab_size = input_vocab_size
        self.output_vocab_size = output_vocab_size

        # Layers are created once here and wired together in build().
        self.encoder_embedding = tf.keras.layers.Embedding(self.input_vocab_size, self.embedding_dim)
        # return_state=True so the encoder's final (h, c) can seed the decoder.
        self.encoder = tf.keras.layers.LSTM(self.hidden_units, return_sequences=True, return_state=True)
        self.decoder_embedding = tf.keras.layers.Embedding(self.output_vocab_size, self.embedding_dim)
        self.decoder = tf.keras.layers.LSTM(self.hidden_units, return_sequences=True, return_state=True)
        # Per-timestep softmax over the output vocabulary.
        self.output_layer = tf.keras.layers.Dense(self.output_vocab_size, activation='softmax')

    def build(self):
        """Wire the layers into a Keras Model, compile it, and return it.

        Returns:
            The compiled tf.keras.Model taking [encoder_inputs,
            decoder_inputs] and producing per-timestep vocabulary
            distributions. (Returning the model is new but backward
            compatible — the original returned None implicitly.)
        """
        # Variable-length integer token sequences for both sides.
        encoder_inputs = tf.keras.Input(shape=(None,))
        encoder_embedding = self.encoder_embedding(encoder_inputs)
        # Sequence outputs are unused here; only the final states matter.
        encoder_outputs, state_h, state_c = self.encoder(encoder_embedding)
        encoder_states = [state_h, state_c]

        decoder_inputs = tf.keras.Input(shape=(None,))
        decoder_embedding = self.decoder_embedding(decoder_inputs)
        # Decoder starts from the encoder's final states (classic seq2seq).
        decoder_outputs, _, _ = self.decoder(decoder_embedding, initial_state=encoder_states)
        outputs = self.output_layer(decoder_outputs)

        self.model = tf.keras.Model([encoder_inputs, decoder_inputs], outputs)
        # Fix: the original configured loss/optimizer/metrics in __init__
        # but never compiled the model, leaving them unused.
        self.model.compile(optimizer=self.optimizer,
                           loss=self.loss,
                           metrics=self.metrics)
        return self.model

    def get(self):
        """Return the underlying Keras model, building it lazily.

        Fix: the original returned None when build() had not been called;
        now the model is constructed on first access.
        """
        if self.model is None:
            self.build()
        return self.model