Spaces:
Sleeping
Sleeping
| # Copyright (C) 2021-2024, Mindee. | |
| # This program is licensed under the Apache License 2.0. | |
| # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details. | |
| from copy import deepcopy | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import tensorflow as tf | |
| from tensorflow.keras import Model, Sequential, layers | |
| from doctr.datasets import VOCABS | |
| from doctr.utils.repr import NestedObject | |
| from ...classification import resnet31 | |
| from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params | |
| from ..core import RecognitionModel, RecognitionPostProcessor | |
__all__ = ["SAR", "sar_resnet31"]

# Default configuration of every architecture exposed by this module, keyed by
# architecture name. `_sar` deep-copies the selected entry before patching it.
default_cfgs: Dict[str, Dict[str, Any]] = {
    "sar_resnet31": dict(
        # presumably per-channel normalization statistics -- not consumed in this file
        mean=(0.694, 0.695, 0.693),
        std=(0.299, 0.296, 0.301),
        # forwarded to the backbone factory unless the caller overrides it
        input_shape=(32, 128, 3),
        # default vocabulary, used when the caller does not pass `vocab`
        vocab=VOCABS["french"],
        # location handed to `load_pretrained_params` when `pretrained=True`
        url="https://doctr-static.mindee.com/models?id=v0.6.0/sar_resnet31-c41e32a5.zip&src=0",
    ),
}
class SAREncoder(layers.Layer, NestedObject):
    """Encoder of the SAR model: two stacked LSTM layers applied to a sequence.

    Args:
    ----
        rnn_units: number of hidden units of each LSTM layer
        dropout_prob: rate forwarded as `recurrent_dropout` to both LSTM layers
    """

    def __init__(self, rnn_units: int, dropout_prob: float = 0.0) -> None:
        super().__init__()
        # The first layer emits the whole sequence so the second one can consume it;
        # the second layer only returns its final output.
        sequence_lstm = layers.LSTM(units=rnn_units, return_sequences=True, recurrent_dropout=dropout_prob)
        summary_lstm = layers.LSTM(units=rnn_units, return_sequences=False, recurrent_dropout=dropout_prob)
        self.rnn = Sequential([sequence_lstm, summary_lstm])

    def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
        """Run the stacked LSTMs on `x`.

        Args:
        ----
            x: input sequence fed to the first LSTM layer
            **kwargs: extra keyword arguments forwarded to the underlying `Sequential`

        Returns:
        -------
            the output of the last LSTM layer, one vector per sample -- (N, C)
        """
        encoded = self.rnn(x, **kwargs)
        return encoded
class AttentionModule(layers.Layer, NestedObject):
    """Attention module of the SAR model: weights a feature map with a softmax
    attention map conditioned on a hidden state, then sums it over axes 1 and 2.

    Args:
    ----
        attention_units: number of channels of the intermediate attention projections
    """

    def __init__(self, attention_units: int) -> None:
        super().__init__()
        # Settings shared by the three convolutions below
        shared_conv_args: Dict[str, Any] = dict(strides=1, padding="same", kernel_initializer="he_normal")
        # 1x1 projection of the hidden state (no bias)
        self.hidden_state_projector = layers.Conv2D(attention_units, 1, use_bias=False, **shared_conv_args)
        # 3x3 projection of the feature map (with bias)
        self.features_projector = layers.Conv2D(attention_units, 3, use_bias=True, **shared_conv_args)
        # 1x1 projection down to a single attention score per location (no bias)
        self.attention_projector = layers.Conv2D(1, 1, use_bias=False, **shared_conv_args)
        self.flatten = layers.Flatten()

    def call(
        self,
        features: tf.Tensor,
        hidden_state: tf.Tensor,
        **kwargs: Any,
    ) -> tf.Tensor:
        """Compute the attended summary ("glimpse") of `features`.

        Args:
        ----
            features: feature map; assumed to be (N, H, W, C) -- axes 1 and 2 are read as H and W
            hidden_state: decoder hidden state; assumed to be (N, rnn_units)
            **kwargs: extra keyword arguments forwarded to every sub-layer

        Returns:
        -------
            `features` weighted by the attention map and summed over axes 1 and 2 -- (N, C)
        """
        # Static spatial dimensions, needed to rebuild the attention map after flattening
        height, width = features.get_shape().as_list()[1:3]

        # (N, H, W, C) -> (N, H, W, attention_units)
        projected_features = self.features_projector(features, **kwargs)

        # Insert two singleton axes so the state broadcasts over every spatial location:
        # (N, rnn_units) -> (N, 1, 1, rnn_units) -> (N, 1, 1, attention_units)
        expanded_state = tf.expand_dims(tf.expand_dims(hidden_state, axis=1), axis=1)
        projected_state = self.hidden_state_projector(expanded_state, **kwargs)

        # (N, H, W, attention_units) -> (N, H, W, 1)
        scores = self.attention_projector(tf.math.tanh(projected_state + projected_features), **kwargs)

        # Flatten so the softmax normalizes over all locations at once: (N, H, W, 1) -> (N, H * W)
        weights = tf.nn.softmax(self.flatten(scores))

        # (N, H * W) -> (N, H, W, 1)
        attention_map = tf.reshape(weights, [-1, height, width, 1])

        # (N, H, W, C) * (N, H, W, 1) -> (N, H, W, C) -> (N, C)
        weighted_features = tf.math.multiply(features, attention_map)
        return tf.reduce_sum(weighted_features, axis=[1, 2])
class SARDecoder(layers.Layer, NestedObject):
    """Decoder of the SAR model: stacked LSTM cells unrolled step by step, with an
    attention glimpse over the feature map at every step.

    Args:
    ----
        rnn_units: number of hidden units in recurrent cells
        max_length: maximum length of a sequence
        vocab_size: number of classes in the model alphabet
        embedding_units: number of hidden embedding units
        attention_units: number of hidden attention units
        num_decoder_cells: number of LSTMCell layers to stack
        dropout_prob: dropout probability
    """

    def __init__(
        self,
        rnn_units: int,
        max_length: int,
        vocab_size: int,
        embedding_units: int,
        attention_units: int,
        num_decoder_cells: int = 2,
        dropout_prob: float = 0.0,
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.max_length = max_length

        # Projects a (vocab_size + 1)-wide vector to the embedding space
        self.embed = layers.Dense(embedding_units, use_bias=False)
        # NOTE(review): Keras `Embedding` takes (input_dim, output_dim) positionally, so this
        # table has `embedding_units` rows of width `vocab_size + 1`; token ids presumably
        # must stay below `embedding_units` -- confirm this is intended.
        self.embed_tgt = layers.Embedding(embedding_units, self.vocab_size + 1)

        decoder_cells = [layers.LSTMCell(rnn_units, implementation=1) for _ in range(num_decoder_cells)]
        self.lstm_cells = layers.StackedRNNCells(decoder_cells)

        self.attention_module = AttentionModule(attention_units)
        self.output_dense = layers.Dense(self.vocab_size + 1, use_bias=True)
        self.dropout = layers.Dropout(dropout_prob)

    def call(
        self,
        features: tf.Tensor,
        holistic: tf.Tensor,
        gt: Optional[tf.Tensor] = None,
        **kwargs: Any,
    ) -> tf.Tensor:
        """Unroll the decoder and return the logits of every decoding step.

        Args:
        ----
            features: feature map attended to at every step; assumed to be (N, H, W, C)
            holistic: vector fed to the LSTM cells at the very first step; assumed to be (N, C)
            gt: optional encoded targets; only used as decoder input when `training=True`
                is passed through `kwargs` (teacher forcing)
            **kwargs: extra keyword arguments forwarded to every sub-layer

        Returns:
        -------
            logits of all steps but the first one -- (N, max_length, vocab_size + 1)
        """
        if gt is not None:
            gt_embedding = self.embed_tgt(gt, **kwargs)

        # Loop invariants
        batch_size = features.shape[0]
        teacher_forcing = gt is not None and kwargs.get("training", False)

        step_logits: List[tf.Tensor] = []
        for step in range(self.max_length + 1):
            if step == 0:
                # Start from fresh LSTM states and feed the holistic vector
                states = self.lstm_cells.get_initial_state(
                    inputs=None, batch_size=batch_size, dtype=features.dtype
                )
                decoder_input = holistic
            elif step == 1:
                # Feed an all-zero ("blank") symbol: (N, vocab_size + 1) -> (N, embedding_units)
                blank_symbol = tf.zeros([batch_size, self.vocab_size + 1], dtype=features.dtype)
                decoder_input = self.embed(blank_symbol, **kwargs)
            elif teacher_forcing:
                # Feed the ground truth symbol; offset of 2 accounts for the two warm-up steps above
                decoder_input = self.embed(gt_embedding[:, step - 2], **kwargs)
            else:
                # Feed back the most likely symbol of the previous step
                predicted_idx = tf.argmax(step_logits[step - 1], axis=-1)
                decoder_input = self.embed(self.embed_tgt(predicted_idx, **kwargs), **kwargs)

            # Advance the stacked cells by one step; only the updated states are kept
            _, states = self.lstm_cells(decoder_input, states, **kwargs)
            # First element of the first entry of the stacked states
            hidden_state = states[0][0]
            # (N, H, W, C), (N, C) --> (N, C)
            glimpse = self.attention_module(features, hidden_state, **kwargs)
            # (N, C), (N, C) --> (N, 2 * C)
            step_features = tf.concat([hidden_state, glimpse], axis=1)
            step_features = self.dropout(step_features, **kwargs)
            # (N, vocab_size + 1)
            step_logits.append(self.output_dense(step_features, **kwargs))

        # Drop the logits of the holistic warm-up step, then put the batch axis first:
        # (max_length, N, vocab_size + 1) --> (N, max_length, vocab_size + 1)
        stacked_logits = tf.stack(step_logits[1:])
        return tf.transpose(stacked_logits, (1, 0, 2))
class SAR(Model, RecognitionModel):
    """Implements a SAR architecture as described in `"Show, Attend and Read:A Simple and Strong Baseline for
    Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.

    Args:
    ----
        feature_extractor: the backbone serving as feature extractor
        vocab: vocabulary used for encoding
        rnn_units: number of hidden units in both encoder and decoder LSTM
        embedding_units: number of embedding units
        attention_units: number of hidden units in attention module
        max_length: maximum word length handled by the model
        num_decoder_cells: number of LSTMCell layers to stack
        dropout_prob: dropout probability for the encoder and decoder
        exportable: onnx exportable returns only logits
        cfg: dictionary containing information about the model
    """

    _children_names: List[str] = ["feat_extractor", "encoder", "decoder", "postprocessor"]

    def __init__(
        self,
        feature_extractor,
        vocab: str,
        rnn_units: int = 512,
        embedding_units: int = 512,
        attention_units: int = 512,
        max_length: int = 30,
        num_decoder_cells: int = 2,
        dropout_prob: float = 0.0,
        exportable: bool = False,
        cfg: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.vocab = vocab
        self.exportable = exportable
        self.cfg = cfg
        self.max_length = max_length + 1  # Add 1 timestep for EOS after the longest word

        self.feat_extractor = feature_extractor

        self.encoder = SAREncoder(rnn_units, dropout_prob)
        self.decoder = SARDecoder(
            rnn_units,
            self.max_length,
            len(vocab),
            embedding_units,
            attention_units,
            num_decoder_cells,
            dropout_prob,
        )

        self.postprocessor = SARPostProcessor(vocab=vocab)

    # Declared static: the function takes no `self`, yet `call` invokes it through the
    # instance. Without the decorator the bound call would pass the model as `model_output`.
    @staticmethod
    def compute_loss(
        model_output: tf.Tensor,
        gt: tf.Tensor,
        seq_len: tf.Tensor,
    ) -> tf.Tensor:
        """Compute categorical cross-entropy loss for the model.
        Sequences are masked after the EOS character.

        Args:
        ----
            model_output: predicted logits of the model
            gt: the encoded tensor with gt labels
            seq_len: lengths of each gt word inside the batch

        Returns:
        -------
            The loss of the model on the batch, one value per sample -- (N, 1)
        """
        # Input length : number of timesteps
        input_len = tf.shape(model_output)[1]
        # Add one for additional <eos> token
        seq_len = seq_len + 1
        # One-hot gt labels, as wide as the last axis of the logits
        oh_gt = tf.one_hot(gt, depth=model_output.shape[2])
        # Per-timestep cross-entropy
        cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt, model_output)
        # Zero out the timesteps located after <eos>
        mask_2d = tf.sequence_mask(seq_len, input_len)
        masked_loss = tf.where(mask_2d, cce, tf.zeros_like(cce))
        # Average over the unmasked timesteps of each sequence
        ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype))
        return tf.expand_dims(ce_loss, axis=1)

    def call(
        self,
        x: tf.Tensor,
        target: Optional[List[str]] = None,
        return_model_output: bool = False,
        return_preds: bool = False,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Run the model on a batch of images.

        Args:
        ----
            x: input batch fed to the feature extractor
            target: optional list of ground truth words; required when `training=True`
            return_model_output: if True, add the raw logits under "out_map"
            return_preds: if True, add decoded predictions under "preds" even when `target` is given
            **kwargs: extra keyword arguments (e.g. `training`) forwarded to every sub-layer

        Returns:
        -------
            a dictionary holding, depending on the arguments: "logits" (exportable mode only),
            "out_map", "preds" and "loss"

        Raises:
        ------
            ValueError: if `training=True` is passed without `target`
        """
        # Fail fast: teacher forcing needs labels, no point in running the backbone first
        if kwargs.get("training", False) and target is None:
            raise ValueError("Need to provide labels during training for teacher forcing")

        features = self.feat_extractor(x, **kwargs)
        # Max pooling over axis 1 (presumably the height axis of an (N, H, W, C) map) --> (N, W, C)
        pooled_features = tf.reduce_max(features, axis=1)
        # holistic (N, C)
        encoded = self.encoder(pooled_features, **kwargs)

        gt, seq_len = None, None
        if target is not None:
            gt, seq_len = self.build_target(target)
            seq_len = tf.cast(seq_len, tf.int32)

        decoded_features = _bf16_to_float32(self.decoder(features, encoded, gt=gt, **kwargs))

        out: Dict[str, tf.Tensor] = {}
        if self.exportable:
            # Export mode: raw logits only, no post-processing nor loss
            out["logits"] = decoded_features
            return out

        if return_model_output:
            out["out_map"] = decoded_features

        if target is None or return_preds:
            # Decode the logits into predictions
            out["preds"] = self.postprocessor(decoded_features)

        if target is not None:
            out["loss"] = self.compute_loss(decoded_features, gt, seq_len)

        return out
class SARPostProcessor(RecognitionPostProcessor):
    """Post processor for SAR architectures

    Args:
    ----
        vocab: string containing the ordered sequence of supported characters
    """

    def __call__(
        self,
        logits: tf.Tensor,
    ) -> List[Tuple[str, float]]:
        """Decode logits into words and confidence scores.

        Args:
        ----
            logits: raw model output; indexed as (N, L, num_classes) by the operations below

        Returns:
        -------
            one (word, confidence) pair per sample, the confidence being clipped to [0, 1]
        """
        # Most likely class at every timestep: N x L
        best_idxs = tf.math.argmax(logits, axis=2)
        # Probability assigned to that class at every timestep: N x L
        step_probs = tf.gather(tf.nn.softmax(logits, axis=-1), best_idxs, axis=-1, batch_dims=2)
        # A word is only as confident as its least confident timestep
        word_probs = tf.math.reduce_min(step_probs, axis=1)

        # Map class indices to strings through `self._embedding`
        # (presumably set up by the parent class -- it is not defined in this file)
        symbol_table = tf.constant(self._embedding, dtype=tf.string)
        symbols = tf.nn.embedding_lookup(symbol_table, tf.cast(best_idxs, dtype="int32"))
        joined = tf.strings.reduce_join(inputs=symbols, axis=-1)

        # Keep only what precedes the first "<eos>" marker
        pieces = tf.strings.split(joined, "<eos>")
        first_pieces = tf.sparse.to_dense(pieces.to_sparse(), default_value="not valid")[:, 0]

        words = [raw_word.decode() for raw_word in first_pieces.numpy().tolist()]
        confidences = word_probs.numpy().clip(0, 1).tolist()
        return list(zip(words, confidences))
def _sar(
    arch: str,
    pretrained: bool,
    backbone_fn,
    pretrained_backbone: bool = True,
    input_shape: Optional[Tuple[int, int, int]] = None,
    **kwargs: Any,
) -> SAR:
    """Build a SAR model for the architecture `arch`.

    Args:
    ----
        arch: key of the architecture in `default_cfgs`
        pretrained: if True, load the parameters located at the "url" of the architecture config
        backbone_fn: callable building the feature extractor; called with `pretrained`,
            `input_shape` and `include_top=False`
        pretrained_backbone: whether the backbone should be pretrained; ignored when `pretrained` is True
        input_shape: optional override of the default input shape
        **kwargs: keyword arguments forwarded to the `SAR` constructor ("vocab" overrides the default one)

    Returns:
    -------
        the SAR model
    """
    # Backbone weights are only requested when the full model is not loaded afterwards
    load_backbone_weights = pretrained_backbone and not pretrained

    # Work on a copy so the module-level defaults are never mutated
    model_cfg = deepcopy(default_cfgs[arch])
    if input_shape:
        model_cfg["input_shape"] = input_shape
    # A caller-provided vocab wins over the default one; both the config and the
    # constructor arguments end up carrying the same value
    kwargs.setdefault("vocab", model_cfg["vocab"])
    model_cfg["vocab"] = kwargs["vocab"]

    backbone = backbone_fn(
        pretrained=load_backbone_weights,
        input_shape=model_cfg["input_shape"],
        include_top=False,
    )

    sar_model = SAR(backbone, cfg=model_cfg, **kwargs)

    if pretrained:
        load_pretrained_params(sar_model, default_cfgs[arch]["url"])

    return sar_model
def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
    """Build a SAR model on top of a resnet-31 feature extractor, following `"Show, Attend and Read:A Simple
    and Strong Baseline for Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.

    >>> import tensorflow as tf
    >>> from doctr.models import sar_resnet31
    >>> model = sar_resnet31(pretrained=False)
    >>> input_tensor = tf.random.uniform(shape=[1, 64, 256, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
        **kwargs: keyword arguments of the SAR architecture

    Returns:
    -------
        text recognition architecture
    """
    arch_name = "sar_resnet31"
    return _sar(arch_name, pretrained, resnet31, **kwargs)