Nah_kagz1092 commited on
Commit
301ca84
·
verified ·
1 Parent(s): 93d7635

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +12 -0
model.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+
3
+ class TransformerModel(tf.keras.Model):
4
+ def __init__(self, config):
5
+ super(TransformerModel, self).__init__()
6
+ self.encoder = tf.keras.layers.TransformerEncoder(config["encoder_layers"])
7
+ self.decoder = tf.keras.layers.TransformerDecoder(config["decoder_layers"])
8
+
9
+ def call(self, inputs, targets):
10
+ encoder_output = self.encoder(inputs)
11
+ decoder_output = self.decoder(targets, encoder_output)
12
+ return decoder_output