Update modeling_protonet.py
Browse files- modeling_protonet.py +1 -1
modeling_protonet.py
CHANGED
|
@@ -854,7 +854,7 @@ class AudioProtoNetPreTrainedModel(PreTrainedModel):
|
|
| 854 |
nn.init.trunc_normal_(module.weight, std=0.02)
|
| 855 |
if module.bias is not None:
|
| 856 |
nn.init.zeros_(module.bias)
|
| 857 |
-
if self.incorrect_class_connection is None
|
| 858 |
# Initialize all weights to the correct_class_connection value
|
| 859 |
self.last_layer.weight.data.fill_(self.correct_class_connection)
|
| 860 |
|
|
|
|
| 854 |
nn.init.trunc_normal_(module.weight, std=0.02)
|
| 855 |
if module.bias is not None:
|
| 856 |
nn.init.zeros_(module.bias)
|
| 857 |
+
if isinstance(module, LinearLayerWithoutNegativeConnections) and self.incorrect_class_connection is None: # TODO missing initilization
|
| 858 |
# Initialize all weights to the correct_class_connection value
|
| 859 |
self.last_layer.weight.data.fill_(self.correct_class_connection)
|
| 860 |
|