Spaces:
Sleeping
File size: 585 Bytes
1e1a4ef 5e5afe8 1e1a4ef | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense
class TKAT(Layer):
    """Additive (Bahdanau-style) attention layer.

    Scores ``values`` against ``query`` as ``V(tanh(W1(query) + W2(values)))``,
    normalises the scores with a softmax over axis 1, and returns the
    attention-weighted sum of ``values`` over that axis.

    Args:
        units: Width of the hidden projection produced by ``W1`` and ``W2``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``,
            ``dtype``).
    """

    def __init__(self, units, **kwargs):
        # Use zero-argument super(): the previous super(AttentionLayer, self)
        # referenced a class name that does not exist here and raised NameError.
        super().__init__(**kwargs)
        self.units = units
        self.W1 = Dense(units)  # projects the query
        self.W2 = Dense(units)  # projects the values
        self.V = Dense(1)       # collapses the hidden projection to one score

    def call(self, query, values):
        """Compute the attention context for ``query`` over ``values``.

        Args:
            query: Query tensor. Assumed to be (batch, hidden) or
                (batch, 1, hidden) — TODO confirm against callers.
            values: Tensor attended over; the softmax normalises over its
                axis 1 (presumably (batch, seq, hidden) — confirm).

        Returns:
            A tuple ``(context_vector, attention_weights)``:
            ``context_vector`` is ``values`` weighted by the attention weights
            and summed over axis 1; ``attention_weights`` is the softmax of the
            scores over axis 1, with a trailing axis of size 1 (from ``V``).
        """
        # A rank-2 query would broadcast its batch axis against the sequence
        # axis of rank-3 values; give it an explicit length-1 time axis so it
        # broadcasts across every position instead. Rank-3 queries are untouched.
        if query.shape.rank == 2 and values.shape.rank == 3:
            query = tf.expand_dims(query, 1)
        score = self.V(tf.nn.tanh(self.W1(query) + self.W2(values)))
        attention_weights = tf.nn.softmax(score, axis=1)
        context_vector = tf.reduce_sum(attention_weights * values, axis=1)
        return context_vector, attention_weights

    def get_config(self):
        """Return the layer config, including ``units``, for Keras serialization."""
        config = super().get_config()
        config.update({"units": self.units})
        return config
|