Spaces:
Build error
Build error
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import GPT2LMHeadModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from lavis.common.registry import registry
from lavis.models.base_model import BaseModel
class GPTDialogue(BaseModel, GPT2LMHeadModel):
    """Video-grounded dialogue model built on GPT-2.

    Raw video features are linearly projected into the GPT-2 embedding space
    and prepended to the token embeddings. Training combines a causal
    language-modeling loss on the shifted text tokens with an MSE regression
    loss reconstructing the shifted video features from the hidden states.
    """

    PRETRAINED_MODEL_CONFIG_DICT = {"base": "configs/models/gpt_dialogue_base.yaml"}

    def __init__(self, config, len_video_ft=4224):
        """Build the video projection layers on top of the GPT-2 backbone.

        Args:
            config: GPT-2 configuration object (must expose ``n_embd``).
            len_video_ft (int): dimensionality of the raw video features.
        """
        super().__init__(config)

        # In/out projections between raw video features and the GPT-2
        # hidden size (used for the video-feature regression head).
        self.video_ff = nn.Linear(len_video_ft, config.n_embd)
        self.video_ff_out = nn.Linear(config.n_embd, len_video_ft)

        # Model parallel (HF GPT-2 convention; disabled by default).
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        samples,
        past_key_values=None,
        position_ids=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Forward pass over concatenated [video; text] embeddings.

        Args:
            samples (dict): expects keys "input_ids", "video_fts",
                "attn_mask", "token_type_ids" and "labels".
                NOTE(review): "video_fts" appears to be
                (batch, n_video_positions, len_video_ft) — confirm with caller.
            The remaining arguments mirror the HF GPT-2 ``forward`` signature
            and are passed straight through to the transformer.

        Returns:
            CausalLMOutputWithCrossAttentions: ``loss`` is the sum of the
            token cross-entropy loss and the video-feature MSE loss
            (whichever of the two are computed).
        """
        # Embed text tokens, project video features, prepend video in front.
        input_embs = self.transformer.wte(samples["input_ids"])
        video_embs = self.video_ff(samples["video_fts"])
        input_embs = torch.cat([video_embs, input_embs], dim=1)

        transformer_outputs = self.transformer(
            attention_mask=samples["attn_mask"],
            token_type_ids=samples["token_type_ids"],
            inputs_embeds=input_embs,
            position_ids=position_ids,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if samples["labels"] is not None:
            # Shift so that tokens < n predict n.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = samples["labels"][..., 1:].contiguous()
            # Flatten the tokens; label -1 marks positions to ignore.
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        if samples["video_fts"] is not None:
            # Regress the shifted video features from the hidden states at
            # the video positions (the front of the concatenated sequence).
            len_video_fts = samples["video_fts"].shape[1]
            video_logits = self.video_ff_out(hidden_states[:, :len_video_fts, :])
            shift_logits = video_logits[..., :-1, :].contiguous()
            shift_labels = samples["video_fts"][..., 1:, :].contiguous()
            loss_fct = MSELoss(reduction="mean")
            video_loss = loss_fct(shift_logits, shift_labels)

            loss = video_loss if loss is None else loss + video_loss

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @classmethod
    def from_config(cls, cfg):
        """Instantiate a pretrained GPT-2 backbone and resize its vocabulary.

        Fix: decorated with ``@classmethod`` — this factory is invoked on the
        class (``GPTDialogue.from_config(cfg)``); without the decorator,
        ``cfg`` was bound to ``cls`` and ``cls.__bases__`` would fail.

        Args:
            cfg (dict-like): must provide "len_tokenizer", the vocabulary
                size to resize the token embeddings to.
        """
        # cls.__bases__[1] is GPT2LMHeadModel (the second base class).
        model = cls.__bases__[1].from_pretrained("gpt2")
        model.resize_token_embeddings(cfg["len_tokenizer"])
        return model