Spaces:
Running
Running
| from torch import nn | |
| import torch | |
| from typing import Optional, Tuple, Union | |
| import collections | |
| import math | |
| from transformers import DonutSwinPreTrainedModel | |
| from transformers.models.donut.modeling_donut_swin import DonutSwinPatchEmbeddings, DonutSwinEmbeddings, DonutSwinModel, \ | |
| DonutSwinEncoder | |
| from surya.model.ordering.config import VariableDonutSwinConfig | |
class VariableDonutSwinEmbeddings(DonutSwinEmbeddings):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.

    In addition to the optional absolute position embeddings, this variant can add
    learned 2D position embeddings: ``row_embeddings`` is indexed along the first
    patch-grid axis and ``column_embeddings`` along the second. In ``forward`` both
    tables are sliced to the patch grid reported by the patch embedding layer for the
    current input, so the input is not tied to a single resolution.
    """

    def __init__(self, config, use_mask_token=False, **kwargs):
        """
        Args:
            config: Model configuration. This constructor reads ``embed_dim``,
                ``use_absolute_embeddings``, ``use_2d_embeddings`` and
                ``hidden_dropout_prob`` from it.
            use_mask_token: When true, allocate a learnable mask token that
                ``forward`` substitutes at the positions flagged by ``bool_masked_pos``.
            **kwargs: Accepted and ignored.
        """
        super().__init__(config, use_mask_token)

        # The attributes below are (re)assigned here even if the parent constructor
        # already set attributes with the same names.
        self.patch_embeddings = DonutSwinPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None

        # Optional absolute position table with num_patches + 1 entries.
        self.position_embeddings = None
        if config.use_absolute_embeddings:
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))

        # Optional 2D position tables, one entry per grid index along each axis (+1).
        self.row_embeddings = None
        self.column_embeddings = None
        if config.use_2d_embeddings:
            self.row_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim))
            self.column_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim))

        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None, **kwargs
    ) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """
        Embed ``pixel_values`` into a normalized patch sequence with position information.

        Args:
            pixel_values: Input passed straight to the patch embedding layer.
            bool_masked_pos: Optional mask; positions that are true are replaced by the
                mask token. It is unsqueezed on the last axis and broadcast against the
                ``(batch_size, seq_len, embed_dim)`` embeddings.
            **kwargs: Accepted and ignored.

        Returns:
            ``(embeddings, output_dimensions)`` where ``embeddings`` has shape
            ``(batch_size, seq_len, embed_dim)`` and ``output_dimensions`` is the value
            returned by the patch embedding layer (presumably the patch-grid
            ``(height, width)`` -- confirm against ``DonutSwinPatchEmbeddings``).

        Raises:
            ValueError: If ``bool_masked_pos`` is given but the module was built with
                ``use_mask_token=False``, so there is no mask token to substitute.
        """
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        # Layernorm across the last dimension (each patch is a single row)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            if self.mask_token is None:
                # Without this guard the next line fails with an opaque
                # AttributeError on None.
                raise ValueError(
                    "bool_masked_pos was provided but this module was created with "
                    "use_mask_token=False, so no mask token is available."
                )
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            embeddings = embeddings + self.position_embeddings[:, :seq_len, :]

        if self.row_embeddings is not None and self.column_embeddings is not None:
            # Row table: each entry is repeated output_dimensions[1] times back to back,
            # i.e. indices 0, 0, ..., 1, 1, ..., so consecutive patches share a row entry.
            row_embeddings = self.row_embeddings[:, :output_dimensions[0], :].repeat_interleave(output_dimensions[1], dim=1)
            # Column table: the whole slice is tiled output_dimensions[0] times,
            # i.e. indices 0, 1, 2, 3, 0, 1, 2, 3, ...
            # NOTE(review): this pairing assumes the patch sequence is flattened with the
            # first grid axis outermost -- confirm against the patch embedding layer.
            column_embeddings = self.column_embeddings[:, :output_dimensions[1], :].repeat(1, output_dimensions[0], 1)
            embeddings = embeddings + row_embeddings + column_embeddings

        embeddings = self.dropout(embeddings)
        return embeddings, output_dimensions
class VariableDonutSwinModel(DonutSwinModel):
    """
    Donut Swin model wired to use ``VariableDonutSwinEmbeddings`` as its embedding layer.
    """

    config_class = VariableDonutSwinConfig

    def __init__(self, config, add_pooling_layer=True, use_mask_token=False, **kwargs):
        """
        Args:
            config: Model configuration; ``depths`` and ``embed_dim`` are read here and
                the object is also handed to the embedding and encoder layers.
            add_pooling_layer: When true, ``self.pooler`` is an ``nn.AdaptiveAvgPool1d(1)``;
                otherwise it is ``None``.
            use_mask_token: Forwarded to ``VariableDonutSwinEmbeddings``.
            **kwargs: Accepted and ignored.
        """
        super().__init__(config)
        self.config = config

        # One entry in config.depths per stage; the feature width is embed_dim
        # multiplied by 2 for every stage after the first.
        stage_count = len(config.depths)
        self.num_layers = stage_count
        self.num_features = int(config.embed_dim * 2 ** (stage_count - 1))

        # The attributes assigned below overwrite anything the parent constructor
        # stored under the same names.
        patch_embedder = VariableDonutSwinEmbeddings(config, use_mask_token=use_mask_token)
        self.embeddings = patch_embedder
        # The encoder is built from the grid size exposed by the embedding layer.
        self.encoder = DonutSwinEncoder(config, patch_embedder.patch_grid)

        if add_pooling_layer:
            self.pooler = nn.AdaptiveAvgPool1d(1)
        else:
            self.pooler = None

        # Initialize weights and apply final processing
        self.post_init()