import tensorflow as tf
class TransformerModel(tf.keras.Model):
    """Encoder-decoder model built from two configurable sub-layers.

    The encoder consumes the source sequence. Its output is passed to the
    decoder together with the target sequence, and the decoder's output is
    returned as the model output.

    NOTE(review): ``tf.keras.layers.Transformer`` does not appear to be part
    of the public ``tf.keras.layers`` API (Keras provides building blocks such
    as ``MultiHeadAttention`` but no ready-made ``Transformer`` layer), so
    construction will likely raise ``AttributeError`` unless that name is
    supplied by a custom/patched build. Confirm which layer class is intended.
    """

    def __init__(self, config, **kwargs):
        """Build the encoder and decoder sub-layers.

        Args:
            config: Mapping with ``"encoder_params"`` and ``"decoder_params"``
                entries. Each entry is expanded as keyword arguments to the
                corresponding layer constructor.
            **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``,
                ``dtype``). Callers passing only ``config`` are unaffected.

        Raises:
            KeyError: If ``config`` lacks either required entry.
        """
        super().__init__(**kwargs)
        self.encoder = tf.keras.layers.Transformer(**config["encoder_params"])
        self.decoder = tf.keras.layers.Transformer(**config["decoder_params"])

    def call(self, inputs, targets=None, training=None):
        """Run the encoder on ``inputs`` and the decoder on ``targets``.

        Args:
            inputs: Source sequence fed to the encoder, or, when ``targets``
                is omitted, an ``(inputs, targets)`` pair. The pair form is
                what Keras ``fit``/``evaluate``/``predict`` deliver, since they
                call ``model(x)`` with a single ``x``.
            targets: Sequence fed to the decoder alongside the encoder output.
            training: Optional training-mode flag forwarded to both
                sub-layers.

        Returns:
            The decoder's output.

        Raises:
            ValueError: If ``targets`` is omitted and ``inputs`` is not a
                two-element tuple/list.
        """
        if targets is None:
            # Support the single-argument calling convention used by Keras'
            # built-in training/inference loops.
            if not isinstance(inputs, (tuple, list)) or len(inputs) != 2:
                raise ValueError(
                    "TransformerModel.call expects `targets`, or `inputs` "
                    "given as an (inputs, targets) pair."
                )
            inputs, targets = inputs
        encoder_output = self.encoder(inputs, training=training)
        decoder_output = self.decoder(targets, encoder_output, training=training)
        return decoder_output