from typing import (
    Optional,
    Tuple,
    Union,
    List,
)

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (
    BartConfig,
    BartPretrainedModel,
)
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.bart.modeling_bart import shift_tokens_right
from transformers.utils import (
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)

from .bart_model import BartCustomModel
from .config import BartCustomConfig
from .custom_constants import BartConstants
from .bart_generation_mixin import GenerationMixin
from .custom_outputs import CustomSeq2SeqLMOutput


logger = logging.get_logger(__name__)


@add_start_docstrings(
    "The BART Model with a language modeling head. Can be used for summarization.", BartConstants.BART_START_DOCSTRING
)
class BartCustomForConditionalGeneration(BartPretrainedModel, GenerationMixin):
    base_model_prefix = "model"
    _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]

    def __init__(self, config: BartCustomConfig):
        super().__init__(config)
        self.model = BartCustomModel(config)
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        new_embeddings = super().resize_token_embeddings(new_num_tokens)
        self._resize_final_logits_bias(new_num_tokens)
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        # Truncate the bias buffer when the vocabulary shrinks, zero-pad it when the vocabulary grows.
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(BartConstants.BART_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=BartConstants.CONFIG_FOR_DOC)
    @add_end_docstrings(BartConstants.BART_GENERATION_EXAMPLE)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        input_commonsense_relations: Optional[torch.Tensor] = None,
        reduce_ce: bool = True,
    ) -> Union[Tuple, CustomSeq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[0, ...,
            config.vocab_size]`. Note that, unlike the stock BART implementation, this model masks the loss on
            padding positions: tokens equal to `config.pad_token_id` are ignored, and the loss is only computed
            for the remaining tokens.
        input_commonsense_relations (`torch.Tensor`, *optional*):
            Commonsense relation inputs forwarded to the underlying `BartCustomModel` as `relation_inputs`.
        reduce_ce (`bool`, *optional*, defaults to `True`):
            Whether to average the cross-entropy loss over tokens. If `False`, the unreduced per-token loss is
            returned.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                # Derive the decoder inputs from the labels when they are not given explicitly.
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            # Extra keyword argument compared to the stock BART model.
            relation_inputs=input_commonsense_relations,
        )
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        masked_lm_loss = None
        if labels is not None:
            # Padding tokens (rather than -100) are ignored here; `reduce_ce=False` yields the unreduced
            # per-token loss. `reduction` replaces the deprecated `reduce` keyword with identical behavior.
            loss_fct = CrossEntropyLoss(
                reduction="mean" if reduce_ce else "none", ignore_index=self.config.pad_token_id
            )
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return CustomSeq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
            head_mask=outputs.encoder_head_mask,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):
        # Cut decoder_input_ids: when past key values are provided, only the last token is needed.
        if past is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,  # encoder_outputs is defined, input_ids are not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            # Cached cross-attention states do not have to be reordered -> they are always the same.
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
            )
        return reordered_past
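

# ---------------------------------------------------------------------------
# Minimal usage sketch (a hedged example, not part of the model definition).
# It assumes `BartCustomConfig` accepts the usual `BartConfig` keyword
# arguments (`vocab_size`, `d_model`, `pad_token_id`, ...) and that
# `BartCustomModel` needs no additional mandatory configuration; adapt the
# values to your checkpoint or custom settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Tiny configuration so the sketch runs quickly on CPU.
    config = BartCustomConfig(
        vocab_size=128,
        d_model=32,
        encoder_layers=1,
        decoder_layers=1,
        encoder_attention_heads=2,
        decoder_attention_heads=2,
        encoder_ffn_dim=64,
        decoder_ffn_dim=64,
        max_position_embeddings=64,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        decoder_start_token_id=2,
    )
    model = BartCustomForConditionalGeneration(config)

    # Dummy batch: two sequences of ten tokens. Passing `labels` makes
    # `forward` derive the decoder inputs via `shift_tokens_right` and return
    # a loss; the optional commonsense relation tensor is left out here.
    input_ids = torch.randint(3, config.vocab_size, (2, 10))
    attention_mask = torch.ones_like(input_ids)
    output = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=input_ids,
    )
    print("loss:", output.loss.item())
    print("logits shape:", tuple(output.logits.shape))

    # Summaries would be produced with `model.generate(...)`, which routes
    # through the custom GenerationMixin and `prepare_inputs_for_generation`.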