# project/tkat.py
# Last updated by nikethanreddy in verified commit 5e5afe8 ("Update tkat.py"); 585 bytes.
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense
class TKAT(Layer):
    """Additive (Bahdanau-style) attention layer.

    Scores ``values`` against ``query`` with
    ``V(tanh(W1(query) + W2(values)))``, normalizes the scores with a softmax
    over axis 1, and returns the attention-weighted sum of ``values`` over
    that same axis.

    Despite its name, this layer implements plain additive attention; the
    class name is kept unchanged so existing callers keep working.

    Args:
        units: Width of the hidden projection used by ``W1`` and ``W2``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``,
            ``dtype``).
    """

    def __init__(self, units, **kwargs):
        # Zero-argument super() avoids hard-coding a class name; the previous
        # explicit form referenced a class that does not exist in this file.
        super().__init__(**kwargs)
        self.units = units  # kept so get_config() can reproduce the layer
        self.W1 = Dense(units)  # projects the query
        self.W2 = Dense(units)  # projects the values
        self.V = Dense(1)  # collapses the projection to one score per step

    def call(self, query, values):
        """Attend over ``values`` using ``query``.

        Args:
            query: Query tensor. Presumably (batch, hidden) or
                (batch, 1, hidden) -- TODO confirm against callers. A rank-2
                query gets a singleton axis inserted at position 1 so that it
                broadcasts across axis 1 of ``values``.
            values: Tensor being attended over. Softmax and the weighted sum
                both run along axis 1, so this is assumed to be
                (batch, steps, features) -- TODO confirm.

        Returns:
            Tuple ``(context_vector, attention_weights)``: the weighted sum of
            ``values`` over axis 1, and the softmax-normalized scores.
        """
        if len(query.shape) == 2:
            # Assuming values is (batch, steps, features): adding
            # (batch, units) to (batch, steps, units) would align batch with
            # steps, so give the query an explicit length-1 step axis.
            query = tf.expand_dims(query, 1)
        score = self.V(tf.nn.tanh(self.W1(query) + self.W2(values)))
        attention_weights = tf.nn.softmax(score, axis=1)
        context_vector = tf.reduce_sum(attention_weights * values, axis=1)
        return context_vector, attention_weights

    def get_config(self):
        """Return the constructor config so the layer can be re-created."""
        config = super().get_config()
        config.update({"units": self.units})
        return config