PeteBleackley commited on
Commit
e556cb6
·
1 Parent(s): 1b76f7d

Ensure weights are trainable

Browse files
qarac/models/layers/GlobalAttentionPoolingHead.py CHANGED
@@ -47,8 +47,12 @@ class GlobalAttentionPoolingHead(keras.layers.Layer):
47
 
48
  """
49
  width = input_shape[-1]
50
- self.global_projection = self.add_weight('global projection',shape=(width,width))
51
- self.local_projection = self.add_weight('local projection',shape=(width,width))
 
 
 
 
52
  self.built=True
53
 
54
  @tensorflow.function
 
47
 
48
  """
49
  width = input_shape[-1]
50
+ self.global_projection = self.add_weight('global projection',
51
+ shape=(width,width),
52
+ Trainable=True)
53
+ self.local_projection = self.add_weight('local projection',
54
+ shape=(width,width),
55
+ trainable=True)
56
  self.built=True
57
 
58
  @tensorflow.function