Spaces:
Build error
Build error
PeteBleackley commited on
Commit ·
8d80339
1
Parent(s): c284c9a
dot_prod needs to unpack arguments from tuple
Browse files
qarac/models/layers/GlobalAttentionPoolingHead.py
CHANGED
|
@@ -11,7 +11,8 @@ import tensorflow
|
|
| 11 |
|
| 12 |
|
| 13 |
@tensorflow.function
|
| 14 |
-
def dot_prod(
|
|
|
|
| 15 |
return tensorflow.tensordot(x,y,axes=1)
|
| 16 |
|
| 17 |
|
|
@@ -81,6 +82,6 @@ class GlobalAttentionPoolingHead(keras.layers.Layer):
|
|
| 81 |
lp = tensorflow.linalg.l2_normalize(tensorflow.vectorized_map(self.project_local,
|
| 82 |
X),
|
| 83 |
axis=2)
|
| 84 |
-
attention = tensorflow.vectorized_map(dot_prod,
|
| 85 |
return tensorflow.reduce_sum(attention *X,
|
| 86 |
axis=1)
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
@tensorflow.function
|
| 14 |
+
def dot_prod(vectors):
|
| 15 |
+
(x,y) = vectors
|
| 16 |
return tensorflow.tensordot(x,y,axes=1)
|
| 17 |
|
| 18 |
|
|
|
|
| 82 |
lp = tensorflow.linalg.l2_normalize(tensorflow.vectorized_map(self.project_local,
|
| 83 |
X),
|
| 84 |
axis=2)
|
| 85 |
+
attention = tensorflow.vectorized_map(dot_prod,(lp,gp))
|
| 86 |
return tensorflow.reduce_sum(attention *X,
|
| 87 |
axis=1)
|