Update modeling_protonet.py
modeling_protonet.py  +5 -16
modeling_protonet.py
CHANGED
@@ -870,14 +870,12 @@ class AudioProtoNetModel(AudioProtoNetPreTrainedModel):
     def forward(
         self,
         input_values: torch.Tensor,
-        output_hidden_states: bool = None
-
-    ) -> tuple | BaseModelOutputWithPoolingAndNoAttention:
+        output_hidden_states: bool = None
+    ) -> BaseModelOutputWithPoolingAndNoAttention:
         """
         Args:
             input_values:
             output_hidden_states:
-            return_dict:
 
         Returns:
             last_hidden_state: torch.FloatTensor = None
@@ -885,7 +883,7 @@ class AudioProtoNetModel(AudioProtoNetPreTrainedModel):
             hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
 
         """
-        return self.backbone(input_values, output_hidden_states
+        return self.backbone(input_values, output_hidden_states)
 
 
 class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
@@ -904,10 +902,9 @@ class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
         prototypes_of_wrong_class: torch.Tensor = None,
         output_hidden_states: bool = None,
         output_prototypical_activations: bool = None,
-
-    ) -> tuple | SequenceClassifierOutputWithProtoTypeActivations:
+    ) -> SequenceClassifierOutputWithProtoTypeActivations:
 
-        backbone_outputs = self.model(input_values, output_hidden_states
+        backbone_outputs = self.model(input_values, output_hidden_states)
 
         last_hidden_state = backbone_outputs[0]
 
@@ -929,14 +926,6 @@ class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
         if output_prototypical_activations is not None:
             prototype_activations = info[4]
 
-        if return_dict:
-            output = (logits,)
-            output += (loss,) if loss is not None else ()
-            output += (last_hidden_state,)
-            output += (hidden_states,) if hidden_states is not None else ()
-            output += (prototype_activations,) if prototype_activations is not None else ()
-            return output
-
         return SequenceClassifierOutputWithProtoTypeActivations(
             logits=logits,
             loss=loss,