from typing import Tuple, Union

from transformers import T5ForConditionalGeneration
from transformers.modeling_outputs import Seq2SeqLMOutput


class T5ForPretrain(T5ForConditionalGeneration):
    """Thin wrapper around T5ForConditionalGeneration for pretraining.

    The forward signature is written out explicitly so that any extra
    keyword arguments (for example, columns injected by a data collator
    or Trainer) are absorbed by ``**kwargs`` instead of being passed to
    the base implementation, which does not accept them.
    """

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,  # default to a Seq2SeqLMOutput instead of a tuple
        **kwargs,  # intentionally ignored; see class docstring
    ) -> Union[Tuple, Seq2SeqLMOutput]:
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class T5ForPretrainDPO(T5ForPretrain):
    """Placeholder for a DPO (Direct Preference Optimization) variant.

    Currently identical to T5ForPretrain; the separate class lets
    DPO-specific behavior be added later without touching the
    pretraining model.
    """
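

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions: a standard Hugging Face T5 checkpoint
# ("t5-small") is reachable and T5TokenizerFast is available; the checkpoint
# name and example strings below are illustrative, not part of this module.
# It shows how the wrapper is loaded and how a forward pass with labels
# yields outputs.loss, since return_dict defaults to True above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import T5TokenizerFast

    tokenizer = T5TokenizerFast.from_pretrained("t5-small")
    model = T5ForPretrain.from_pretrained("t5-small")

    # Encode a toy source/target pair; pad token ids in the labels are
    # replaced with -100 so the cross-entropy loss ignores them.
    inputs = tokenizer("translate English to German: Hello", return_tensors="pt")
    labels = tokenizer("Hallo", return_tensors="pt").input_ids
    labels[labels == tokenizer.pad_token_id] = -100

    outputs = model(**inputs, labels=labels)
    print(float(outputs.loss))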