Spaces:
Build error
Build error
# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| import math | |
| from functools import partial | |
| from itertools import permutations | |
| from typing import Sequence, Any, Optional | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
| from pytorch_lightning.utilities.types import STEP_OUTPUT | |
| from timm.models.helpers import named_apply | |
| from strhub.models.base import CrossEntropySystem | |
| from strhub.models.utils import init_weights | |
| from .modules import DecoderLayer, Decoder, Encoder, TokenEmbedding | |
| import sys | |
class PARSeq(CrossEntropySystem):
    """Scene text recognition system: an image encoder plus a decoder driven by learned position queries.

    Training (`training_step`) averages the cross-entropy loss over several permutation orderings of the
    target sequence; each ordering is turned into a pair of attention masks by `generate_attn_masks`.
    Inference (`forward`) decodes either autoregressively or in a single pass, optionally followed by
    iterative refinement.
    """

    def __init__(self, charset_train: str, charset_test: str, max_label_length: int,
                 batch_size: int, lr: float, warmup_pct: float, weight_decay: float,
                 img_size: Sequence[int], patch_size: Sequence[int], embed_dim: int,
                 enc_num_heads: int, enc_mlp_ratio: int, enc_depth: int,
                 dec_num_heads: int, dec_mlp_ratio: int, dec_depth: int,
                 perm_num: int, perm_forward: bool, perm_mirrored: bool,
                 decode_ar: bool, refine_iters: int, dropout: float, **kwargs: Any) -> None:
        """Build the encoder, decoder, output head and embeddings.

        :param charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay:
            forwarded unchanged to the base class constructor.
        :param max_label_length: maximum number of characters; one extra position is reserved for <eos>.
        :param img_size, patch_size, embed_dim, enc_num_heads, enc_mlp_ratio, enc_depth: encoder configuration.
        :param dec_num_heads, dec_mlp_ratio, dec_depth: decoder configuration
            (the decoder feed-forward width is ``embed_dim * dec_mlp_ratio``).
        :param perm_num: number of permutations used per training step (halved internally when mirrored,
            since each generated permutation then also contributes its flipped counterpart).
        :param perm_forward: always include the canonical left-to-right ordering.
        :param perm_mirrored: pair every permutation with its reversed version.
        :param decode_ar: use autoregressive decoding in `forward`.
        :param refine_iters: number of refinement passes in `forward` (0 disables refinement).
        :param dropout: dropout probability for the decoder layer and the embedding dropout.
        :param kwargs: accepted and ignored here.
        """
        super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
        self.save_hyperparameters()

        self.max_label_length = max_label_length
        self.decode_ar = decode_ar
        self.refine_iters = refine_iters

        self.encoder = Encoder(img_size, patch_size, embed_dim=embed_dim, depth=enc_depth, num_heads=enc_num_heads,
                               mlp_ratio=enc_mlp_ratio)
        decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
        self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=nn.LayerNorm(embed_dim))

        # Perm/attn mask stuff
        self.rng = np.random.default_rng()
        self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
        self.perm_forward = perm_forward
        self.perm_mirrored = perm_mirrored

        # We don't predict <bos> nor <pad>
        self.head = nn.Linear(embed_dim, len(self.tokenizer) - 2)
        self.text_embed = TokenEmbedding(len(self.tokenizer), embed_dim)

        # +1 for <eos>
        self.pos_queries = nn.Parameter(torch.Tensor(1, max_label_length + 1, embed_dim))
        self.dropout = nn.Dropout(p=dropout)
        # Encoder has its own init.
        named_apply(partial(init_weights, exclude=['encoder']), self)
        nn.init.trunc_normal_(self.pos_queries, std=.02)

    def no_weight_decay(self):
        """Return the names of parameters to exclude from weight decay.

        Combines this module's own exclusions with the encoder's, prefixing the latter with ``encoder.``.
        Presumably consumed by the optimizer setup in the base class (not visible here).
        """
        param_names = {'text_embed.embedding.weight', 'pos_queries'}
        enc_param_names = {'encoder.' + n for n in self.encoder.no_weight_decay()}
        return param_names.union(enc_param_names)

    def encode(self, img: torch.Tensor):
        """Run the image encoder and return its output (used as decoder memory)."""
        return self.encoder(img)

    def decode(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[Tensor] = None,
               tgt_padding_mask: Optional[Tensor] = None, tgt_query: Optional[Tensor] = None,
               tgt_query_mask: Optional[Tensor] = None):
        """Embed the token context and run the decoder.

        :param tgt: token ids of shape (N, L); position 0 is expected to hold <bos>.
        :param memory: encoder output.
        :param tgt_mask: attention mask for the token context (passed through to the decoder).
        :param tgt_padding_mask: padding mask for the token context (passed through to the decoder).
        :param tgt_query: query embeddings; defaults to the first L learned position queries.
        :param tgt_query_mask: attention mask for the queries (passed through to the decoder).
        :return: the decoder output for the queries.
        """
        N, L = tgt.shape
        # <bos> stands for the null context. We only supply position information for characters after <bos>.
        null_ctx = self.text_embed(tgt[:, :1])
        tgt_emb = self.pos_queries[:, :L - 1] + self.text_embed(tgt[:, 1:])
        tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))
        if tgt_query is None:
            tgt_query = self.pos_queries[:, :L].expand(N, -1, -1)
        tgt_query = self.dropout(tgt_query)
        return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask)

    def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
        """Predict character logits for a batch of images.

        :param images: batch of input images (first dimension is the batch size).
        :param max_length: maximum number of characters to decode, capped at ``max_label_length``.
            ``None`` selects ``max_label_length`` and additionally enables early stopping of the
            autoregressive loop once every sequence in the batch contains <eos>.
        :return: logits with ``len(tokenizer) - 2`` classes per step. The step dimension has
            ``max_length + 1`` entries, except when AR decoding stops early and no refinement follows,
            in which case it only covers the steps that were actually decoded.
        """
        testing = max_length is None
        max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
        bs = images.shape[0]
        # +1 for <eos> at end of sequence.
        num_steps = max_length + 1
        memory = self.encode(images)

        # Query positions up to `num_steps`
        pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)

        # Special case for the forward permutation. Faster than using `generate_attn_masks()`
        # NOTE: `tgt_mask` and `query_mask` are deliberately left as the SAME tensor object, as in the
        # original code; the in-place edit in the refinement stage below therefore affects both.
        tgt_mask = query_mask = torch.triu(torch.full((num_steps, num_steps), float('-inf'), device=self._device), 1)

        if self.decode_ar:
            tgt_in = torch.full((bs, num_steps), self.pad_id, dtype=torch.long, device=self._device)
            tgt_in[:, 0] = self.bos_id

            logits = []
            for i in range(num_steps):
                j = i + 1  # next token index
                # Efficient decoding:
                # Input the context up to the ith token. We use only one query (at position = i) at a time.
                # This works because of the lookahead masking effect of the canonical (forward) AR context.
                # Past tokens have no access to future tokens, hence are fixed once computed.
                tgt_out = self.decode(tgt_in[:, :j], memory, tgt_mask[:j, :j], tgt_query=pos_queries[:, i:j],
                                      tgt_query_mask=query_mask[i:j, :j])
                # the next token probability is in the output's ith token position
                p_i = self.head(tgt_out)
                logits.append(p_i)
                if j < num_steps:
                    # greedy decode. add the next token index to the target input
                    # Only the single-step dimension is squeezed so the batch dimension is always kept.
                    tgt_in[:, j] = p_i.squeeze(1).argmax(-1)
                    # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
                    if testing and (tgt_in == self.eos_id).any(dim=-1).all():
                        break

            logits = torch.cat(logits, dim=1)
        else:
            # No prior context, so input is just <bos>. We query all positions.
            tgt_in = torch.full((bs, 1), self.bos_id, dtype=torch.long, device=self._device)
            tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
            logits = self.head(tgt_out)

        if self.refine_iters:
            # For iterative refinement, we always use a 'cloze' mask.
            # We can derive it from the AR forward mask by unmasking the token context to the right.
            query_mask[torch.triu(torch.ones(num_steps, num_steps, dtype=torch.bool, device=self._device), 2)] = 0
            bos = torch.full((bs, 1), self.bos_id, dtype=torch.long, device=self._device)
            for _ in range(self.refine_iters):
                # Prior context is the previous output.
                tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
                # The context can be shorter than `num_steps` when AR decoding stopped early.
                ctx_len = tgt_in.shape[1]
                # Mask tokens beyond the first EOS token. The explicit int cast keeps the running count
                # from depending on cumsum accepting a bool tensor.
                tgt_padding_mask = ((tgt_in == self.eos_id).int().cumsum(-1) > 0)
                # Slice the context mask to the context length, mirroring the AR loop above. This is a
                # no-op when the context spans all `num_steps` positions.
                tgt_out = self.decode(tgt_in, memory, tgt_mask[:ctx_len, :ctx_len], tgt_padding_mask,
                                      tgt_query=pos_queries, tgt_query_mask=query_mask[:, :ctx_len])
                logits = self.head(tgt_out)

        return logits

    def gen_tgt_perms(self, tgt):
        """Generate shared permutations for the whole batch.

        This works because the same attention mask can be used for the shorter sequences
        because of the padding mask.

        :param tgt: encoded target batch; its second dimension includes the BOS and EOS positions.
        :return: integer tensor with one permutation per row. Column 0 is always the BOS position (0) and
            the remaining columns hold position indices 1..T+1 (T = number of character positions).
        """
        # We don't permute the position of BOS, we permute EOS separately
        max_num_chars = tgt.shape[1] - 2
        # Special handling for 1-character sequences
        if max_num_chars == 1:
            return torch.arange(3, device=self._device).unsqueeze(0)
        perms = [torch.arange(max_num_chars, device=self._device)] if self.perm_forward else []
        # Additional permutations if needed
        max_perms = math.factorial(max_num_chars)
        if self.perm_mirrored:
            max_perms //= 2
        num_gen_perms = min(self.max_gen_perms, max_perms)
        # For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions
        # Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars.
        if max_num_chars < 5:
            # Pool of permutations to sample from. We only need the first half (if complementary option is selected)
            # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
            if max_num_chars == 4 and self.perm_mirrored:
                selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
            else:
                selector = list(range(max_perms))
            perm_pool = torch.as_tensor(list(permutations(range(max_num_chars), max_num_chars)),
                                        device=self._device)[selector]
            # If the forward permutation is always selected, no need to add it to the pool for sampling
            if self.perm_forward:
                perm_pool = perm_pool[1:]
            # torch.stack() rejects an empty list, which is what `perms` is when perm_forward is off;
            # start from an empty (0, T) tensor in that case so sampling from the pool still works.
            perms = torch.stack(perms) if perms else perm_pool.new_empty((0, max_num_chars))
            if len(perm_pool):
                i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(perms), replace=False)
                perms = torch.cat([perms, perm_pool[i]])
        else:
            perms.extend([torch.randperm(max_num_chars, device=self._device)
                          for _ in range(num_gen_perms - len(perms))])
            perms = torch.stack(perms)
        if self.perm_mirrored:
            # Add complementary pairs
            comp = perms.flip(-1)
            # Stack in such a way that the pairs are next to each other.
            perms = torch.stack([perms, comp]).transpose(0, 1).reshape(-1, max_num_chars)
        # NOTE:
        # The only meaningful way of permuting the EOS position is by moving it one character position at a time.
        # However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS
        # positions will always be much less than the number of permutations (unless a low perm_num is set).
        # Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly
        # distribute it across the chosen number of permutations.
        # Add position indices of BOS and EOS
        bos_idx = perms.new_zeros((len(perms), 1))
        eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1)
        perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1)
        # Special handling for the reverse direction. This does two things:
        # 1. Reverse context for the characters
        # 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode)
        if len(perms) > 1:
            perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=self._device)
        return perms

    def generate_attn_masks(self, perm):
        """Generate attention masks given a sequence permutation (includes pos. for bos and eos tokens)

        :param perm: the permutation sequence. i = 0 is always the BOS
        :return: lookahead attention masks as a ``(content_mask, query_mask)`` pair, each of shape
            ``(sz - 1, sz - 1)`` where ``sz = len(perm)``; masked entries are ``-inf``, the rest 0.
        """
        sz = perm.shape[0]
        mask = torch.zeros((sz, sz), device=self._device)
        for i in range(sz):
            # Each position may not attend to the positions that come later in the permutation order.
            query_idx = perm[i]
            masked_keys = perm[i + 1:]
            mask[query_idx, masked_keys] = float('-inf')
        # Snapshot taken BEFORE the diagonal is masked: the content mask keeps "self" visible.
        content_mask = mask[:-1, :-1].clone()
        mask[torch.eye(sz, dtype=torch.bool, device=self._device)] = float('-inf')  # mask "self"
        query_mask = mask[1:, :-1]
        return content_mask, query_mask

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        """Compute the training loss, averaged over all generated permutations.

        :param batch: an ``(images, labels)`` pair.
        :param batch_idx: index of the batch (unused).
        :return: the token-count-weighted mean cross-entropy over all permutations.
        """
        images, labels = batch
        tgt = self.tokenizer.encode(labels, self._device)

        # Encode the source sequence (i.e. the image codes)
        memory = self.encode(images)

        # Prepare the target sequences (input and output)
        tgt_perms = self.gen_tgt_perms(tgt)
        tgt_in = tgt[:, :-1]
        tgt_out = tgt[:, 1:]
        # The [EOS] token is not depended upon by any other token in any permutation ordering
        tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)

        loss = 0
        loss_numel = 0
        # Number of non-pad target tokens; used to weight each permutation's mean loss.
        n = (tgt_out != self.pad_id).sum().item()
        for i, perm in enumerate(tgt_perms):
            tgt_mask, query_mask = self.generate_attn_masks(perm)
            out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask)
            logits = self.head(out).flatten(end_dim=1)
            loss += n * F.cross_entropy(logits, tgt_out.flatten(), ignore_index=self.pad_id)
            loss_numel += n
            # After the second iteration (i.e. done with canonical and reverse orderings),
            # remove the [EOS] tokens for the succeeding perms
            if i == 1:
                tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id, tgt_out)
                n = (tgt_out != self.pad_id).sum().item()
        loss /= loss_numel

        self.log('loss', loss)
        return loss