mwirth7 commited on
Commit
3c867d8
·
verified ·
1 Parent(s): 9bd92f6

Update modeling_protonet.py

Browse files
Files changed (1) hide show
  1. modeling_protonet.py +1 -1
modeling_protonet.py CHANGED
@@ -854,7 +854,7 @@ class AudioProtoNetPreTrainedModel(PreTrainedModel):
854
  nn.init.trunc_normal_(module.weight, std=0.02)
855
  if module.bias is not None:
856
  nn.init.zeros_(module.bias)
857
- if self.incorrect_class_connection is None and isinstance(self.last_layer, LinearLayerWithoutNegativeConnections): # TODO missing initilization
858
  # Initialize all weights to the correct_class_connection value
859
  self.last_layer.weight.data.fill_(self.correct_class_connection)
860
 
 
854
  nn.init.trunc_normal_(module.weight, std=0.02)
855
  if module.bias is not None:
856
  nn.init.zeros_(module.bias)
857
+ if isinstance(module, LinearLayerWithoutNegativeConnections) and self.incorrect_class_connection is None: # TODO missing initilization
858
  # Initialize all weights to the correct_class_connection value
859
  self.last_layer.weight.data.fill_(self.correct_class_connection)
860