Spaces:
Running
Running
| from typing import Optional, Union, Tuple | |
| import torch | |
| from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig | |
| from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput | |
| from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import shift_tokens_right | |
| from surya.model.recognition.encoder import DonutSwinModel | |
| from surya.model.recognition.decoder import SuryaOCRDecoder, SuryaOCRTextEncoder | |
class OCREncoderDecoderModel(PreTrainedModel):
    """OCR model that bundles a vision encoder, a decoder and a text encoder.

    The three sub-models are built from ``config.encoder``, ``config.decoder``
    and ``config.text_encoder`` unless pre-built instances are supplied.

    NOTE(review): ``forward`` only chains ``encoder`` -> ``decoder``; the
    ``text_encoder`` sub-module is stored on the model but is not invoked by
    any method of this class.
    """

    config_class = VisionEncoderDecoderConfig
    base_model_prefix = "vision_encoder_decoder"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _supports_param_buffer_assignment = False

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
        text_encoder: Optional[PreTrainedModel] = None,
    ):
        """Build the model from a composite config and optional sub-models.

        Args:
            config: Composite configuration exposing ``encoder``, ``decoder``
                and ``text_encoder`` sub-configs. Required despite the
                ``None`` default in the signature.
            encoder: Pre-built encoder; a ``DonutSwinModel`` is created from
                ``config.encoder`` when omitted.
            decoder: Pre-built decoder; a ``SuryaOCRDecoder`` is created from
                ``config.decoder`` when omitted.
            text_encoder: Pre-built text encoder; a ``SuryaOCRTextEncoder`` is
                created from ``config.text_encoder`` when omitted.

        Raises:
            ValueError: If ``config`` is ``None``.
        """
        if config is None:
            # Fail early with a readable message instead of an AttributeError
            # on ``None.tie_word_embeddings`` below.
            raise ValueError("A configuration has to be provided to build an OCREncoderDecoderModel.")

        # Make sure input and output embeddings are not tied. This mutates the
        # caller's config object before it is handed to the base class.
        config.tie_word_embeddings = False
        config.decoder.tie_word_embeddings = False
        super().__init__(config)

        if encoder is None:
            encoder = DonutSwinModel(config.encoder)

        if decoder is None:
            decoder = SuryaOCRDecoder(config.decoder, attn_implementation=config._attn_implementation)

        if text_encoder is None:
            text_encoder = SuryaOCRTextEncoder(
                config.text_encoder, attn_implementation=config._attn_implementation
            )

        self.encoder = encoder
        self.decoder = decoder
        self.text_encoder = text_encoder

        # Point every sub-model at the sub-config held by the shared config so
        # that later updates to the shared config stay in sync.
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder
        self.text_encoder.config = self.config.text_encoder

    def get_encoder(self):
        """Return the vision encoder sub-model."""
        return self.encoder

    def get_decoder(self):
        """Return the decoder sub-model."""
        return self.decoder

    def get_output_embeddings(self):
        """Return the decoder's output embeddings."""
        return self.decoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Replace the decoder's output embeddings with ``new_embeddings``."""
        return self.decoder.set_output_embeddings(new_embeddings)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_cache_position: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Encode the image (unless already encoded) and run the decoder.

        Args:
            pixel_values: Encoder input; required when ``encoder_outputs`` is
                not supplied.
            decoder_input_ids: Passed to the decoder as ``input_ids``.
            decoder_cache_position: Passed to the decoder as ``cache_position``.
            decoder_attention_mask: Passed to the decoder as ``attention_mask``.
            encoder_outputs: Precomputed encoder outputs. A plain tuple is
                wrapped into a ``BaseModelOutput``.
            use_cache: Passed through to the decoder.
            **kwargs: Keys starting with ``decoder_`` are forwarded to the
                decoder with the prefix removed; all other keys are forwarded
                to the encoder, and only when the encoder is actually run.

        Returns:
            A ``Seq2SeqLMOutput`` carrying the decoder logits, the decoder
            hidden states and the encoder's last hidden state.

        Raises:
            ValueError: If neither ``pixel_values`` nor ``encoder_outputs`` is
                given.
            AttributeError: If the encoder hidden states need to be projected
                but the model has no ``enc_to_dec_proj`` layer.
        """
        prefix = "decoder_"
        kwargs_encoder = {name: value for name, value in kwargs.items() if not name.startswith(prefix)}
        kwargs_decoder = {
            name[len(prefix):]: value for name, value in kwargs.items() if name.startswith(prefix)
        }

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # Optionally project encoder_hidden_states to the decoder width.
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            # This class never creates ``enc_to_dec_proj`` itself, so look it
            # up defensively and explain the failure instead of surfacing a
            # bare missing-attribute error.
            projection = getattr(self, "enc_to_dec_proj", None)
            if projection is None:
                raise AttributeError(
                    "Encoder and decoder hidden sizes differ and "
                    "`decoder.config.cross_attention_hidden_size` is None, but this model defines no "
                    "`enc_to_dec_proj` layer to project the encoder hidden states."
                )
            encoder_hidden_states = projection(encoder_hidden_states)

        # No encoder attention mask is passed to the decoder's cross-attention.
        encoder_attention_mask = None

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            cache_position=decoder_cache_position,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            **kwargs_decoder,
        )

        return Seq2SeqLMOutput(
            logits=decoder_outputs.logits,
            decoder_hidden_states=decoder_outputs.hidden_states,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Derive decoder input ids from ``labels`` via ``shift_tokens_right``.

        Uses ``config.pad_token_id`` and ``config.decoder_start_token_id``.
        """
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        """Assemble the keyword arguments for one generation step.

        The decoder's own ``prepare_inputs_for_generation`` supplies the
        decoder input ids, the cache and (optionally) the decoder attention
        mask. Extra ``**kwargs`` are accepted but ignored.

        NOTE(review): ``forward`` has no explicit ``attention_mask`` or
        ``past_key_values`` parameters, so those two keys end up in its
        ``**kwargs`` and are only forwarded to the encoder when it is run.
        """
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
        # The decoder may or may not provide an attention mask.
        decoder_attention_mask = decoder_inputs.get("attention_mask")
        return {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }

    def resize_token_embeddings(self, *args, **kwargs):
        """Always raise: embeddings must be resized on the wrapped decoder.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError(
            "Resizing the embedding layers via the OCREncoderDecoderModel directly is not supported. Please use the"
            " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
        )

    def _reorder_cache(self, past_key_values, beam_idx):
        """Delegate cache reordering to the decoder."""
        return self.decoder._reorder_cache(past_key_values, beam_idx)