Spaces:
Running
Running
| from typing import Optional, Union, Tuple, List | |
| import torch | |
| from transformers import VisionEncoderDecoderModel | |
| from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput | |
class OrderVisionEncoderDecoderModel(VisionEncoderDecoderModel):
    """Encoder-decoder model whose forward pass feeds box inputs to the decoder.

    Only ``forward`` is overridden here; all other behavior comes from
    ``VisionEncoderDecoderModel``.
    """

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_boxes: Optional[torch.LongTensor] = None,
        decoder_input_boxes_mask: Optional[torch.LongTensor] = None,
        decoder_input_boxes_counts: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[List[List[int]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Run the encoder (unless its outputs are supplied) and then the decoder.

        Args:
            pixel_values: Image input passed to the encoder. Required when
                ``encoder_outputs`` is None.
            decoder_input_boxes: Shape (batch_size, num_boxes, 4), all coords
                scaled 0 - 1000, with 1001 as padding.
            decoder_input_boxes_mask: Shape (batch_size, num_boxes), 0 if
                padding, 1 otherwise.
            decoder_input_boxes_counts: Shape (batch_size), number of boxes in
                each image.
            encoder_outputs: Precomputed encoder outputs. When given, the
                encoder is not run; a plain tuple is wrapped in
                ``BaseModelOutput``.
            past_key_values: Forwarded to the decoder unchanged.
            decoder_inputs_embeds: Forwarded to the decoder as
                ``inputs_embeds``.
            labels: Forwarded to the decoder unchanged.
            use_cache: Forwarded to the decoder unchanged.
            output_attentions: Forwarded to both encoder and decoder.
            output_hidden_states: Forwarded to both encoder and decoder.
            return_dict: Whether to return a ``Seq2SeqLMOutput``; defaults to
                ``self.config.use_return_dict`` when None.
            **kwargs: Extra arguments. Keys starting with ``decoder_`` have
                the prefix stripped and go to the decoder; all other keys go
                to the encoder.

        Returns:
            A ``Seq2SeqLMOutput`` when ``return_dict`` is truthy, otherwise the
            decoder output tuple followed by the encoder output tuple.

        Raises:
            ValueError: If both ``encoder_outputs`` and ``pixel_values`` are
                None.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Split the extra keyword arguments by their "decoder_" prefix.
        kwargs_encoder = {
            argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")
        }
        kwargs_decoder = {
            argument[len("decoder_"):]: value
            for argument, value in kwargs.items()
            if argument.startswith("decoder_")
        }

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # Project the encoder states only when the hidden sizes differ and the
        # decoder config declares no cross-attention hidden size.
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # No encoder attention mask is supplied to the decoder.
        encoder_attention_mask = None

        decoder_outputs = self.decoder(
            input_boxes=decoder_input_boxes,
            input_boxes_mask=decoder_input_boxes_mask,
            input_boxes_counts=decoder_input_boxes_counts,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            labels=labels,
            **kwargs_decoder,
        )

        if not return_dict:
            # Caller-supplied encoder outputs are not a plain tuple at this
            # point (a tuple input was wrapped in BaseModelOutput above), and
            # a dict-like output cannot be concatenated with the decoder's
            # tuple. Flatten it first so this path no longer raises TypeError.
            if not isinstance(encoder_outputs, tuple):
                encoder_outputs = encoder_outputs.to_tuple()
            return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=decoder_outputs.loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )