Spaces:
Sleeping
Sleeping
| import tensorflow as tf | |
| from tensorflow.keras.layers import Layer, Dense | |
class TKAT(Layer):
    """Additive attention layer built from three dense projections.

    The score for each position is ``V(tanh(W1(query) + W2(values)))``; the
    scores are normalized with a softmax over axis 1 and used to form a
    weighted sum of ``values`` along that same axis.
    """

    def __init__(self, units, **kwargs):
        """Create the projection sub-layers.

        Args:
            units: Output width of the ``W1`` and ``W2`` projections.
            **kwargs: Standard Keras ``Layer`` keyword arguments (for example
                ``name`` or ``dtype``), forwarded to the base class.
        """
        # Zero-argument super() avoids naming the class explicitly; the
        # original referenced ``AttentionLayer``, which is not this class and
        # is not defined in the visible file.
        super().__init__(**kwargs)
        self.units = units
        self.W1 = Dense(units)  # projects the query
        self.W2 = Dense(units)  # projects the values
        self.V = Dense(1)  # collapses each projected position to one score

    def call(self, query, values):
        """Compute the attention context for ``query`` over ``values``.

        Args:
            query: Tensor projected by ``W1``.
            values: Tensor projected by ``W2`` and then weighted.
                NOTE(review): ``W1(query)`` and ``W2(values)`` are added
                directly, so the caller must pass broadcast-compatible shapes
                (presumably ``values`` is (batch, steps, features) and
                ``query`` already carries a steps axis) -- confirm at the
                call site.

        Returns:
            A ``(context_vector, attention_weights)`` tuple:
            ``attention_weights`` is the softmax of the scores over axis 1,
            and ``context_vector`` is ``attention_weights * values`` summed
            over axis 1.
        """
        score = self.V(tf.nn.tanh(self.W1(query) + self.W2(values)))
        # Normalize over axis 1 so the weights along that axis sum to one.
        attention_weights = tf.nn.softmax(score, axis=1)
        # The trailing size-1 score dimension broadcasts across the last
        # axis of ``values``.
        context_vector = tf.reduce_sum(attention_weights * values, axis=1)
        return context_vector, attention_weights

    def get_config(self):
        """Return the layer config, including ``units``, for serialization."""
        config = super().get_config()
        config.update({"units": self.units})
        return config