Text Recognition Model

CRNN-based text recognition model trained on the ICDAR 2003 dataset for optical character recognition (OCR) tasks.

Model

Architecture: CRNN with CTC decoding
Backbone: ResNet34
Input size: (100, 420)
Framework: PyTorch

This model takes a cropped grayscale text image as input and predicts the corresponding character sequence.

Architecture Details

The model consists of three main components:

  • CNN backbone: ResNet34 feature extractor adapted for single-channel input images.
  • Sequence modeling: Multi-layer bidirectional GRU for contextual encoding along the text sequence.
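  • Prediction head: Layer normalization and a linear projection followed by log-softmax. It produces one output per vocabulary character plus one extra class (the CTC blank) for CTC-based decoding.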

Vocabulary

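The recognition vocabulary consists of the 10 digits and the 26 lowercase English letters (36 characters in total). The `idx_2_label` map in the usage example also lists `-` at index 1.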

- 0 1 2 3 4 5 6 7 8 9 a b c d e f g h i j k l m n o p q r s t u v w x y z

Intended Use

This model is designed for text recognition on cropped word images. It is suitable for OCR pipelines where text regions have already been detected and normalized before being passed to the recognizer.

Usage

import json
import torch
import torch.nn as nn
import timm
from huggingface_hub import hf_hub_download


class CRNNModel(nn.Module):
    """CRNN text recognizer: ResNet34 features -> bidirectional GRU -> per-step log-probs.

    The network returns log-probabilities in time-major layout
    ``(T, N, vocab_size + 1)``, the layout ``nn.CTCLoss`` expects for ``log_probs``.

    Args:
        vocab_size: Number of character classes. One extra output class is
            appended on top of these (presumably the CTC blank).
        emb_dim: Size of the projection applied to backbone features before the GRU.
        hidden_size: GRU hidden size per direction.
        dropout_prob: Dropout after the projection and between stacked GRU layers.
        num_layers: Number of stacked GRU layers.
        unfreeze_layer: Number of trailing backbone modules kept trainable. All
            earlier backbone modules are frozen, and ``0`` freezes the whole backbone.
        pretrained: Whether timm loads pretrained ResNet34 weights. Pass ``False``
            when the weights will be overwritten by ``load_state_dict`` anyway.
    """

    def __init__(self, vocab_size, emb_dim=512, hidden_size=256, dropout_prob=0.2,
                 num_layers=3, unfreeze_layer=3, pretrained=True):
        super().__init__()

        # Backbone CNN: ResNet34 for single-channel input with its last two
        # children dropped (the pooling/classifier head). Height is then pooled
        # to 1, so each remaining feature column becomes one time step.
        cnn_model = timm.create_model('resnet34', pretrained=pretrained, in_chans=1)
        feature_layers = list(cnn_model.children())[:-2]
        feature_layers.append(nn.AdaptiveAvgPool2d((1, None)))
        self.backbone = nn.Sequential(*feature_layers)

        # Freeze the whole backbone first, then re-enable gradients for the last
        # `unfreeze_layer` modules. Only setting requires_grad=True would be a
        # no-op, because every parameter starts out trainable.
        for param in self.backbone.parameters():
            param.requires_grad = False
        if unfreeze_layer > 0:  # `[-0:]` would select the entire backbone
            for param in self.backbone[-unfreeze_layer:].parameters():
                param.requires_grad = True

        # 512 = channel count of ResNet34's final stage.
        self.linear_layer = nn.Sequential(
            nn.Linear(512, emb_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob)
        )

        self.rnn_layer = nn.GRU(
            input_size=emb_dim,
            hidden_size=hidden_size,
            bidirectional=True,
            batch_first=True,
            num_layers=num_layers,
            # nn.GRU only applies dropout between layers; it warns if set with 1 layer.
            dropout=dropout_prob if num_layers > 1 else 0
        )

        # Bidirectional GRU concatenates both directions -> 2 * hidden_size features.
        self.layernorm = nn.LayerNorm(hidden_size * 2)

        self.output = nn.Sequential(
            nn.Linear(hidden_size * 2, vocab_size + 1),
            nn.LogSoftmax(dim=2)
        )

    def forward(self, x):
        """Return log-probabilities of shape ``(T, N, vocab_size + 1)``.

        Args:
            x: Image batch, presumably ``(N, 1, H, W)`` grayscale. Confirm
                against the preprocessing used in training.
        """
        # Use mixed precision only for CUDA inputs. Decorating forward with
        # @torch.autocast(device_type="cuda") would build the context once at
        # import time, warning and disabling itself on hosts without CUDA.
        if x.is_cuda:
            with torch.autocast(device_type="cuda"):
                return self._forward_impl(x)
        return self._forward_impl(x)

    def _forward_impl(self, x):
        """Run the network end to end without any autocast handling."""
        x = self.backbone(x)       # (N, C, 1, W')
        x = x.permute(0, 3, 1, 2)  # (N, W', C, 1): width becomes the sequence axis
        x = x.flatten(2)           # (N, W', C)
        x = self.linear_layer(x)
        x, _ = self.rnn_layer(x)   # (N, W', 2 * hidden_size)
        x = self.layernorm(x)
        x = self.output(x)         # (N, W', vocab_size + 1)
        x = x.permute(1, 0, 2)     # (W', N, vocab_size + 1): time-major for CTC
        return x


# Index -> character lookup for the recognizer's output classes: '-' at index 1,
# then the digits '0'-'9' (indices 2-11) and the letters 'a'-'z' (indices 12-37).
# Index 0 is left unmapped (presumably reserved for the CTC blank).
# NOTE: load_from_hub returns the map stored in the repo's idx_to_char.json
# rather than this table.
idx_2_label = dict(enumerate('-0123456789abcdefghijklmnopqrstuvwxyz', start=1))


def load_from_hub(repo_id, device="cpu"):
    """Download the recognizer from the Hugging Face Hub and return it ready for inference.

    Fetches ``pytorch_model.bin``, ``config.json`` and ``idx_to_char.json`` from
    ``repo_id``. It then builds a ``CRNNModel`` with ``config["num_classes"]``
    classes, loads the weights and switches the model to eval mode.

    Args:
        repo_id: Hub repository id, e.g. ``"user/repo"``.
        device: Device the model and its loaded weights are placed on.

    Returns:
        ``(model, idx_to_char)``: the model in eval mode, and the contents of
        ``idx_to_char.json``. If that JSON is an object, its keys are converted
        to ``int``; otherwise it is returned as loaded.

    Raises:
        KeyError: if ``config.json`` has no ``"num_classes"`` entry.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    vocab_path = hf_hub_download(repo_id=repo_id, filename="idx_to_char.json")

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    with open(vocab_path, "r", encoding="utf-8") as f:
        idx_to_char = json.load(f)

    # JSON object keys are always strings; restore the integer class indices.
    if isinstance(idx_to_char, dict):
        idx_to_char = {int(k): v for k, v in idx_to_char.items()}

    model = CRNNModel(
        vocab_size=config["num_classes"]
    ).to(device)

    # The checkpoint is a pickle downloaded from a remote repository.
    # weights_only=True limits unpickling to tensors and plain containers, so
    # a tampered file cannot execute arbitrary code (needs torch >= 1.13).
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    return model, idx_to_char


# Run on the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Download the weights, config and vocabulary, and build the recognizer
# together with its index -> character map.
text_recognition_model, idx_to_char = load_from_hub(
    repo_id="huytqvn/text-recognition-pipeline",
    device=device,
)
Downloads last month
26
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support