PJAITEST1903 / model.py
Nah_kagz1092
Update model.py
1f1e697 verified
raw
history blame
607 Bytes
import tensorflow as tf
def _feed_forward(d_model, dff, dropout_rate):
    """Build the position-wise feed-forward sublayer: d_model -> dff -> d_model."""
    return tf.keras.Sequential(
        [
            tf.keras.layers.Dense(dff, activation="relu"),
            tf.keras.layers.Dense(d_model),
            tf.keras.layers.Dropout(dropout_rate),
        ]
    )


class _EncoderLayer(tf.keras.layers.Layer):
    """One post-norm encoder layer: self-attention then feed-forward, each residual."""

    def __init__(self, d_model, num_heads, dff, dropout_rate):
        super().__init__()
        self.self_attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=d_model // num_heads, dropout=dropout_rate
        )
        self.feed_forward = _feed_forward(d_model, dff, dropout_rate)
        self.attention_norm = tf.keras.layers.LayerNormalization()
        self.ffn_norm = tf.keras.layers.LayerNormalization()

    def call(self, x, training=None):
        attended = self.self_attention(query=x, value=x, key=x, training=training)
        x = self.attention_norm(x + attended)
        return self.ffn_norm(x + self.feed_forward(x, training=training))


class _DecoderLayer(tf.keras.layers.Layer):
    """One post-norm decoder layer: causal self-attention, cross-attention, feed-forward."""

    def __init__(self, d_model, num_heads, dff, dropout_rate):
        super().__init__()
        key_dim = d_model // num_heads
        self.self_attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=key_dim, dropout=dropout_rate
        )
        self.cross_attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=key_dim, dropout=dropout_rate
        )
        self.feed_forward = _feed_forward(d_model, dff, dropout_rate)
        self.self_attention_norm = tf.keras.layers.LayerNormalization()
        self.cross_attention_norm = tf.keras.layers.LayerNormalization()
        self.ffn_norm = tf.keras.layers.LayerNormalization()

    def call(self, x, context, training=None):
        # Causal mask so decoder position i cannot attend to later positions.
        attended = self.self_attention(
            query=x, value=x, key=x, use_causal_mask=True, training=training
        )
        x = self.self_attention_norm(x + attended)
        # Cross-attention: decoder queries attend over the encoder output.
        attended = self.cross_attention(
            query=x, value=context, key=context, training=training
        )
        x = self.cross_attention_norm(x + attended)
        return self.ffn_norm(x + self.feed_forward(x, training=training))


class _TransformerEncoder(tf.keras.layers.Layer):
    """Project input features to d_model, then apply ``num_layers`` encoder layers."""

    def __init__(self, num_layers, d_model, num_heads, dff, dropout_rate):
        super().__init__()
        # Projection lets inputs of any feature width line up with the residual adds.
        self.input_proj = tf.keras.layers.Dense(d_model)
        self.blocks = [
            _EncoderLayer(d_model, num_heads, dff, dropout_rate)
            for _ in range(num_layers)
        ]

    def call(self, x, training=None):
        x = self.input_proj(x)
        for block in self.blocks:
            x = block(x, training=training)
        return x


class _TransformerDecoder(tf.keras.layers.Layer):
    """Project input features to d_model, then apply ``num_layers`` decoder layers."""

    def __init__(self, num_layers, d_model, num_heads, dff, dropout_rate):
        super().__init__()
        self.input_proj = tf.keras.layers.Dense(d_model)
        self.blocks = [
            _DecoderLayer(d_model, num_heads, dff, dropout_rate)
            for _ in range(num_layers)
        ]

    def call(self, x, context, training=None):
        x = self.input_proj(x)
        for block in self.blocks:
            x = block(x, context, training=training)
        return x


class TransformerModel(tf.keras.Model):
    """Encoder-decoder Transformer assembled from core ``tf.keras`` layers.

    ``tf.keras.layers`` provides no ``TransformerEncoder``/``TransformerDecoder``
    (those belong to the separate KerasNLP/KerasHub package), so both stacks
    are built here from ``MultiHeadAttention``, ``LayerNormalization``,
    ``Dense`` and ``Dropout``. Requires a TensorFlow/Keras version whose
    ``MultiHeadAttention`` accepts ``use_causal_mask``.

    Args:
        config: Mapping with required keys ``"encoder_layers"`` and
            ``"decoder_layers"``, used as the number of stacked encoder /
            decoder layers (NOTE(review): assumed to be layer counts; confirm
            with callers). Optional keys: ``"d_model"`` (default 512),
            ``"num_heads"`` (default 8), ``"dff"`` (default ``4 * d_model``)
            and ``"dropout_rate"`` (default 0.1).

    Raises:
        ValueError: If ``d_model`` is not a positive multiple of ``num_heads``.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.get("d_model", 512)
        num_heads = config.get("num_heads", 8)
        dff = config.get("dff", 4 * d_model)
        dropout_rate = config.get("dropout_rate", 0.1)
        if num_heads <= 0 or d_model <= 0 or d_model % num_heads:
            raise ValueError(
                f"d_model ({d_model}) must be a positive multiple of "
                f"num_heads ({num_heads})"
            )
        hyperparams = dict(
            d_model=d_model, num_heads=num_heads, dff=dff, dropout_rate=dropout_rate
        )
        self.encoder = _TransformerEncoder(config["encoder_layers"], **hyperparams)
        self.decoder = _TransformerDecoder(config["decoder_layers"], **hyperparams)

    def call(self, inputs, targets=None, training=None):
        """Encode ``inputs`` and decode ``targets`` against the encoder output.

        Args:
            inputs: Encoder input sequence, presumably ``(batch, src_len,
                features)`` as ``MultiHeadAttention`` expects. May instead be a
                two-element list/tuple ``[inputs, targets]`` when ``targets``
                is omitted, which is how Keras ``fit()``/``predict()`` pass a
                single ``x``.
            targets: Decoder input sequence, presumably ``(batch, tgt_len,
                features)`` (e.g. right-shifted targets — confirm with caller).
            training: Forwarded to dropout/attention layers.

        Returns:
            Decoder hidden states of width ``d_model``.

        Raises:
            ValueError: If ``targets`` is omitted and ``inputs`` is not a pair.
        """
        if targets is None:
            if not isinstance(inputs, (list, tuple)) or len(inputs) != 2:
                raise ValueError(
                    "TransformerModel expects (inputs, targets) or a single "
                    "[inputs, targets] pair"
                )
            inputs, targets = inputs
        encoder_output = self.encoder(inputs, training=training)
        return self.decoder(targets, encoder_output, training=training)

    def predict(self, input_data, *args, **kwargs):
        """Predict outputs for ``input_data`` using Keras' batched inference loop.

        Previously a stub returning ``None``, which shadowed
        ``tf.keras.Model.predict``. ``input_data`` should be a two-element
        ``[inputs, targets]`` sequence (see ``call``). Extra positional and
        keyword arguments (``batch_size``, ``verbose``, ...) are forwarded to
        ``tf.keras.Model.predict``.
        """
        return super().predict(input_data, *args, **kwargs)