import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense


class TKAT(Layer):
    """Additive (tanh-scored) attention over a set of value vectors.

    The score for each position is ``V(tanh(W1(query) + W2(values)))``.
    Scores are normalised with a softmax along axis 1, and the values are
    summed along that same axis using the resulting weights.
    """

    def __init__(self, units, **kwargs):
        """Create the three projection layers used to score the inputs.

        Args:
            units: Output width of the ``W1`` and ``W2`` projections.
            **kwargs: Standard Keras ``Layer`` keyword arguments (``name``,
                ``dtype``, ``trainable``, ...), forwarded unchanged.
        """
        # The original called ``super(AttentionLayer, self)``, but no
        # ``AttentionLayer`` name exists for this class; the zero-argument
        # form always resolves to the class being defined.
        super().__init__(**kwargs)
        self.units = units
        self.W1 = Dense(units)
        self.W2 = Dense(units)
        self.V = Dense(1)

    def call(self, query, values):
        """Compute the attention-weighted sum of ``values``.

        Args:
            query: Tensor projected through ``W1``. Its projection must be
                broadcastable against the projection of ``values``.
                Assumes a shape like (batch, 1, dim) -- TODO confirm
                against the callers.
            values: Tensor projected through ``W2`` and then weighted.
                Assumes a shape like (batch, steps, dim) with axis 1 being
                the attended axis -- TODO confirm against the callers.

        Returns:
            A ``(context_vector, attention_weights)`` tuple, where
            ``attention_weights`` is the softmax of the scores along axis 1
            and ``context_vector`` is ``attention_weights * values`` summed
            along axis 1.
        """
        # Project both inputs to ``units`` and combine them additively.
        score = self.V(tf.nn.tanh(self.W1(query) + self.W2(values)))

        # Normalise along axis 1 so the weights on that axis sum to one.
        attention_weights = tf.nn.softmax(score, axis=1)

        # The last dimension of the weights is 1 (from ``V``), so it
        # broadcasts across the last dimension of ``values``.
        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights

    def get_config(self):
        """Return the layer configuration, including ``units``."""
        config = super().get_config()
        config.update({"units": self.units})
        return config