PeteBleackley commited on
Commit
8d80339
·
1 Parent(s): c284c9a

dot_prod needs to unpack arguments from tuple

Browse files
qarac/models/layers/GlobalAttentionPoolingHead.py CHANGED
@@ -11,7 +11,8 @@ import tensorflow
11
 
12
 
13
  @tensorflow.function
14
- def dot_prod(x,y):
 
15
  return tensorflow.tensordot(x,y,axes=1)
16
 
17
 
@@ -81,6 +82,6 @@ class GlobalAttentionPoolingHead(keras.layers.Layer):
81
  lp = tensorflow.linalg.l2_normalize(tensorflow.vectorized_map(self.project_local,
82
  X),
83
  axis=2)
84
- attention = tensorflow.vectorized_map(dot_prod,[lp,gp])
85
  return tensorflow.reduce_sum(attention *X,
86
  axis=1)
 
11
 
12
 
13
  @tensorflow.function
14
+ def dot_prod(vectors):
15
+ (x,y) = vectors
16
  return tensorflow.tensordot(x,y,axes=1)
17
 
18
 
 
82
  lp = tensorflow.linalg.l2_normalize(tensorflow.vectorized_map(self.project_local,
83
  X),
84
  axis=2)
85
+ attention = tensorflow.vectorized_map(dot_prod,(lp,gp))
86
  return tensorflow.reduce_sum(attention *X,
87
  axis=1)