Update model_class.py
#5
by
geekweilai
- opened
- model_class.py +2 -0
model_class.py
CHANGED
|
@@ -26,6 +26,7 @@ class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
|
|
| 26 |
output_hidden_states: Optional[bool] = None,
|
| 27 |
return_dict: Optional[bool] = None,
|
| 28 |
forced_ac_decoder_ids: Optional[torch.LongTensor] = None, # added to be ignored when passed from trainer
|
|
|
|
| 29 |
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
|
| 30 |
return super().forward(
|
| 31 |
input_features=input_features,
|
|
@@ -43,6 +44,7 @@ class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
|
|
| 43 |
output_attentions=output_attentions,
|
| 44 |
output_hidden_states=output_hidden_states,
|
| 45 |
return_dict=return_dict,
|
|
|
|
| 46 |
)
|
| 47 |
|
| 48 |
# copy-pasted and adapted from transformers.WhisperForConditionalGeneration.generate
|
|
|
|
| 26 |
output_hidden_states: Optional[bool] = None,
|
| 27 |
return_dict: Optional[bool] = None,
|
| 28 |
forced_ac_decoder_ids: Optional[torch.LongTensor] = None, # added to be ignored when passed from trainer
|
| 29 |
+ decoder_position_ids: Optional[torch.LongTensor] = None,
|
| 30 |
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
|
| 31 |
return super().forward(
|
| 32 |
input_features=input_features,
|
|
|
|
| 44 |
output_attentions=output_attentions,
|
| 45 |
output_hidden_states=output_hidden_states,
|
| 46 |
return_dict=return_dict,
|
| 47 |
+ decoder_position_ids=decoder_position_ids,
|
| 48 |
)
|
| 49 |
|
| 50 |
# copy-pasted and adapted from transformers.WhisperForConditionalGeneration.generate
|