Spaces:
Build error
Build error
PeteBleackley
commited on
Commit
·
e556cb6
1
Parent(s):
1b76f7d
Ensure weights are trainable
Browse files
qarac/models/layers/GlobalAttentionPoolingHead.py
CHANGED
|
@@ -47,8 +47,12 @@ class GlobalAttentionPoolingHead(keras.layers.Layer):
|
|
| 47 |
|
| 48 |
"""
|
| 49 |
width = input_shape[-1]
|
| 50 |
-
self.global_projection = self.add_weight('global projection',
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
self.built=True
|
| 53 |
|
| 54 |
@tensorflow.function
|
|
|
|
| 47 |
|
| 48 |
"""
|
| 49 |
width = input_shape[-1]
|
| 50 |
+
self.global_projection = self.add_weight('global projection',
|
| 51 |
+
shape=(width,width),
|
| 52 |
+
Trainable=True)
|
| 53 |
+
self.local_projection = self.add_weight('local projection',
|
| 54 |
+
shape=(width,width),
|
| 55 |
+
trainable=True)
|
| 56 |
self.built=True
|
| 57 |
|
| 58 |
@tensorflow.function
|