Update modeling_protonet.py
Browse files- modeling_protonet.py +2 -8
modeling_protonet.py
CHANGED
|
@@ -854,7 +854,8 @@ class AudioProtoNetPreTrainedModel(PreTrainedModel):
|
|
| 854 |
nn.init.trunc_normal_(module.weight, std=0.02)
|
| 855 |
if module.bias is not None:
|
| 856 |
nn.init.zeros_(module.bias)
|
| 857 |
-
if isinstance(
|
|
|
|
| 858 |
self.last_layer.weight.data.fill_(self.correct_class_connection)
|
| 859 |
|
| 860 |
|
|
@@ -896,13 +897,6 @@ class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
|
|
| 896 |
self.model = AudioProtoNetModel(config)
|
| 897 |
self.head = AudioProtoNetClassificationHead(config)
|
| 898 |
|
| 899 |
-
|
| 900 |
-
def freeze_backbone(self):
|
| 901 |
-
pass
|
| 902 |
-
|
| 903 |
-
def int2str(self): # TODO
|
| 904 |
-
pass
|
| 905 |
-
|
| 906 |
def forward(
|
| 907 |
self,
|
| 908 |
input_values: torch.Tensor,
|
|
|
|
| 854 |
nn.init.trunc_normal_(module.weight, std=0.02)
|
| 855 |
if module.bias is not None:
|
| 856 |
nn.init.zeros_(module.bias)
|
| 857 |
+
if isinstance(module, LinearLayerWithoutNegativeConnections) and self.incorrect_class_connection is None:
|
| 858 |
+
# Initialize all weights to the correct_class_connection value
|
| 859 |
self.last_layer.weight.data.fill_(self.correct_class_connection)
|
| 860 |
|
| 861 |
|
|
|
|
| 897 |
self.model = AudioProtoNetModel(config)
|
| 898 |
self.head = AudioProtoNetClassificationHead(config)
|
| 899 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 900 |
def forward(
|
| 901 |
self,
|
| 902 |
input_values: torch.Tensor,
|