Update modeling_protonet.py
Browse files- modeling_protonet.py +6 -24
modeling_protonet.py
CHANGED
|
@@ -854,7 +854,7 @@ class AudioProtoNetPreTrainedModel(PreTrainedModel):
|
|
| 854 |
nn.init.trunc_normal_(module.weight, std=0.02)
|
| 855 |
if module.bias is not None:
|
| 856 |
nn.init.zeros_(module.bias)
|
| 857 |
-
if isinstance(module, LinearLayerWithoutNegativeConnections) and self.incorrect_class_connection is None:
|
| 858 |
# Initialize all weights to the correct_class_connection value
|
| 859 |
self.last_layer.weight.data.fill_(self.correct_class_connection)
|
| 860 |
|
|
@@ -870,14 +870,12 @@ class AudioProtoNetModel(AudioProtoNetPreTrainedModel):
|
|
| 870 |
def forward(
|
| 871 |
self,
|
| 872 |
input_values: torch.Tensor,
|
| 873 |
-
output_hidden_states: bool = None
|
| 874 |
-
|
| 875 |
-
) -> tuple | BaseModelOutputWithPoolingAndNoAttention:
|
| 876 |
"""
|
| 877 |
Args:
|
| 878 |
input_values:
|
| 879 |
output_hidden_states:
|
| 880 |
-
return_dict:
|
| 881 |
|
| 882 |
Returns:
|
| 883 |
last_hidden_state: torch.FloatTensor = None
|
|
@@ -885,7 +883,7 @@ class AudioProtoNetModel(AudioProtoNetPreTrainedModel):
|
|
| 885 |
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 886 |
|
| 887 |
"""
|
| 888 |
-
return self.backbone(input_values, output_hidden_states
|
| 889 |
|
| 890 |
|
| 891 |
class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
|
|
@@ -897,13 +895,6 @@ class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
|
|
| 897 |
self.model = AudioProtoNetModel(config)
|
| 898 |
self.head = AudioProtoNetClassificationHead(config)
|
| 899 |
|
| 900 |
-
|
| 901 |
-
def freeze_backbone(self):
|
| 902 |
-
pass
|
| 903 |
-
|
| 904 |
-
def int2str(self): # TODO
|
| 905 |
-
pass
|
| 906 |
-
|
| 907 |
def forward(
|
| 908 |
self,
|
| 909 |
input_values: torch.Tensor,
|
|
@@ -911,10 +902,9 @@ class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
|
|
| 911 |
prototypes_of_wrong_class: torch.Tensor = None,
|
| 912 |
output_hidden_states: bool = None,
|
| 913 |
output_prototypical_activations: bool = None,
|
| 914 |
-
|
| 915 |
-
) -> tuple | SequenceClassifierOutputWithProtoTypeActivations:
|
| 916 |
|
| 917 |
-
backbone_outputs = self.model(input_values, output_hidden_states
|
| 918 |
|
| 919 |
last_hidden_state = backbone_outputs[0]
|
| 920 |
|
|
@@ -936,14 +926,6 @@ class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
|
|
| 936 |
if output_prototypical_activations is not None:
|
| 937 |
prototype_activations = info[4]
|
| 938 |
|
| 939 |
-
if return_dict:
|
| 940 |
-
output = (logits,)
|
| 941 |
-
output += (loss, ) if loss is not None else ()
|
| 942 |
-
output += (last_hidden_state, )
|
| 943 |
-
output += (hidden_states, ) if hidden_states is not None else ()
|
| 944 |
-
output += (prototype_activations,) if prototype_activations is not None else ()
|
| 945 |
-
return output
|
| 946 |
-
|
| 947 |
return SequenceClassifierOutputWithProtoTypeActivations(
|
| 948 |
logits=logits,
|
| 949 |
loss=loss,
|
|
|
|
| 854 |
nn.init.trunc_normal_(module.weight, std=0.02)
|
| 855 |
if module.bias is not None:
|
| 856 |
nn.init.zeros_(module.bias)
|
| 857 |
+
if isinstance(module, LinearLayerWithoutNegativeConnections) and self.incorrect_class_connection is None:
|
| 858 |
# Initialize all weights to the correct_class_connection value
|
| 859 |
self.last_layer.weight.data.fill_(self.correct_class_connection)
|
| 860 |
|
|
|
|
| 870 |
def forward(
    self,
    input_values: torch.Tensor,
    output_hidden_states: bool | None = None,
) -> BaseModelOutputWithPoolingAndNoAttention:
    """Encode a batch of audio inputs with the backbone.

    This is a thin delegation wrapper: all computation happens inside
    ``self.backbone``, which is called positionally with both arguments.

    Args:
        input_values:
            Input audio tensor. The exact shape contract is defined by the
            backbone (presumably ``(batch, ...)``) — TODO confirm against
            the backbone implementation.
        output_hidden_states:
            Whether the backbone should also return its intermediate
            hidden states. ``None`` defers to the backbone's own default.

    Returns:
        The backbone's output, a ``BaseModelOutputWithPoolingAndNoAttention``
        carrying ``last_hidden_state`` and, when requested,
        ``hidden_states``.
    """
    # NOTE(review): forwarded positionally — the backbone's second
    # positional parameter is assumed to be output_hidden_states.
    return self.backbone(input_values, output_hidden_states)
|
| 887 |
|
| 888 |
|
| 889 |
class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
|
|
|
|
| 895 |
self.model = AudioProtoNetModel(config)
|
| 896 |
self.head = AudioProtoNetClassificationHead(config)
|
| 897 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 898 |
def forward(
|
| 899 |
self,
|
| 900 |
input_values: torch.Tensor,
|
|
|
|
| 902 |
prototypes_of_wrong_class: torch.Tensor = None,
|
| 903 |
output_hidden_states: bool = None,
|
| 904 |
output_prototypical_activations: bool = None,
|
| 905 |
+
) -> SequenceClassifierOutputWithProtoTypeActivations:
|
|
|
|
| 906 |
|
| 907 |
+
backbone_outputs = self.model(input_values, output_hidden_states)
|
| 908 |
|
| 909 |
last_hidden_state = backbone_outputs[0]
|
| 910 |
|
|
|
|
| 926 |
if output_prototypical_activations is not None:
|
| 927 |
prototype_activations = info[4]
|
| 928 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 929 |
return SequenceClassifierOutputWithProtoTypeActivations(
|
| 930 |
logits=logits,
|
| 931 |
loss=loss,
|