project / tkan.py
nikethanreddy's picture
Upload 6 files
1e1a4ef verified
raw
history blame
564 Bytes
import tensorflow as tf
from tensorflow.keras import layers
class TKAN(tf.keras.layers.Layer):
    """Two stacked ``Dense`` layers with ReLU activation, each ``units`` wide.

    The forward pass is ``relu(Dense(relu(Dense(inputs))))``. Both Dense layers
    have the same output width, ``units``.

    NOTE(review): the class name suggests a Temporal Kolmogorov-Arnold Network,
    but the implementation shown here is only a plain two-layer feed-forward
    block. Confirm this is the intended model.

    Args:
        units: Output width of both Dense layers, and therefore the size of the
            last axis of this layer's output.
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``
            (e.g. ``name``, ``dtype``, ``trainable``).
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Attribute names are kept stable so existing weight names and
        # checkpoints continue to match.
        self.dense1 = layers.Dense(units, activation="relu")
        self.dense2 = layers.Dense(units, activation="relu")

    def call(self, inputs):
        """Run ``inputs`` through both Dense+ReLU layers in sequence.

        Args:
            inputs: Input tensor passed to the first Dense layer.

        Returns:
            The output tensor of the second Dense layer.
        """
        x = self.dense1(inputs)
        return self.dense2(x)

    def get_config(self):
        """Return the base layer config extended with ``units``.

        The default ``Layer.from_config`` calls ``cls(**config)``. Including
        ``units`` here therefore lets Keras re-create this layer from its
        config, e.g. when a saved model is reloaded with this class supplied
        as a custom object.
        """
        config = super().get_config()
        config.update({"units": self.units})
        return config