{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "XZTfJ2fy7Bxu"
},
"source": [
"## Fine-tune TrOCR on the IAM Handwriting Database\n",
"\n",
"In this notebook, we are going to fine-tune a pre-trained TrOCR model on the [IAM Handwriting Database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database), a collection of annotated images of handwritten text.\n",
"\n",
"We will do this using the new `VisionEncoderDecoderModel` class, which can be used to combine any image Transformer encoder (such as ViT, BEiT) with any text Transformer as decoder (such as BERT, RoBERTa, GPT-2). TrOCR is an instance of this: it has an encoder-decoder architecture, with the weights of the encoder initialized from a pre-trained BEiT and the weights of the decoder initialized from a pre-trained RoBERTa. The cross-attention layers were randomly initialized, before the authors pre-trained the model further on millions of (partially synthetic) annotated images of handwritten text.\n",
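"\n",
"As a hedged illustration of this warm-starting (the checkpoint names below are examples, not necessarily the ones we fine-tune in this notebook):\n",
"\n",
"```python\n",
"from transformers import VisionEncoderDecoderModel\n",
"\n",
"# Pair any pre-trained vision encoder with any pre-trained text decoder;\n",
"# the cross-attention weights connecting them are randomly initialized by this call.\n",
"model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(\n",
"    \"google/vit-base-patch16-224-in21k\", \"roberta-base\"\n",
")\n",
"```\n",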
"\n",
"This figure gives a good overview of the model (from the original paper):\n",
"\n",
"\n",
"\n",
"* TrOCR paper: https://arxiv.org/abs/2109.10282\n",
"* TrOCR documentation: https://huggingface.co/transformers/master/model_doc/trocr.html\n",
"\n",
"\n",
"Note that Patrick also wrote a very good [blog post](https://huggingface.co/blog/warm-starting-encoder-decoder) on warm-starting encoder-decoder models (which is what the TrOCR authors did). That blog post was very helpful to me when creating this notebook.\n",
"\n",
"We will fine-tune the model using the Seq2SeqTrainer, a subclass of the 🤗 Trainer that lets you compute generative metrics such as BLEU and ROUGE by running generation (i.e. calling the `generate` method) inside the evaluation loop.\n",
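"\n",
"As a hedged sketch of such a metric (illustrative only: `processor` stands for a TrOCR processor assumed to be defined elsewhere, and the exact metric used later may differ), a character error rate (CER) function based on `jiwer` could look like:\n",
"\n",
"```python\n",
"from jiwer import cer\n",
"\n",
"def compute_metrics(pred):\n",
"    # `processor` is assumed to be a TrOCRProcessor created elsewhere (hypothetical here).\n",
"    pred_str = processor.batch_decode(pred.predictions, skip_special_tokens=True)\n",
"    label_str = processor.batch_decode(pred.label_ids, skip_special_tokens=True)\n",
"    # jiwer.cer(reference, hypothesis) returns the character error rate\n",
"    return {\"cer\": cer(label_str, pred_str)}\n",
"```\n",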
"\n",
"\n",
"\n",
"## Set-up environment\n",
"\n",
"First, let's install the required libraries:\n",
"* Transformers (for the TrOCR model)\n",
"* Datasets & Jiwer (for the evaluation metric)\n",
"\n",
"We will not use HuggingFace Datasets for data preprocessing in this notebook; instead, we will create a good old basic PyTorch Dataset."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "pkSzlRJq68tH"
},
"outputs": [],
"source": [
"!pip install -q transformers"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "a8eZ6PWTHriw"
},
"outputs": [],
"source": [
"!pip install -q datasets jiwer"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "lsTaPrDR7My2"
},
"source": [
"## Prepare data\n",
"\n",
"We first download the data. Here, I'm using only the IAM test set, which the TrOCR authors released in the unilm repository. It can be downloaded from [this page](https://github.com/microsoft/unilm/tree/master/trocr).\n",
"\n",
"Let's make a [regular PyTorch dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). We first create a Pandas dataframe with two columns: each row consists of the file name of an image and the corresponding text."
]
},
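{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal, hedged sketch (the class name is illustrative; a real training dataset would also load each image and tokenize the text, e.g. with a TrOCR processor), a bare-bones PyTorch `Dataset` over such a dataframe could look like:\n",
"\n",
"```python\n",
"from torch.utils.data import Dataset\n",
"\n",
"class OCRDataset(Dataset):\n",
"    # Wraps a dataframe with `file_name` and `text` columns.\n",
"    def __init__(self, df):\n",
"        self.df = df\n",
"\n",
"    def __len__(self):\n",
"        return len(self.df)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        row = self.df.iloc[idx]\n",
"        return row[\"file_name\"], row[\"text\"]\n",
"```"
]
},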
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"id": "KkHqJw-W9Abl",
"outputId": "0db0a26a-4749-4b28-cd6f-fb355ad6aaf5"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>file_name</th>\n",
"      <th>text</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>sai1 bian1 da3 wang3.png</td>\n",
"      <td>sai1 bian1 da3 wang3</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>que2 ya2 ba1.png</td>\n",
"      <td>que2 ya2 ba1</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>huang2 tong3.png</td>\n",
"      <td>huang2 tong3</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>do1.png</td>\n",
"      <td>do1</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>mao2 gou3.png</td>\n",
"      <td>mao2 gou3</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"