diff --git "a/train_ocr.ipynb" "b/train_ocr.ipynb"
new file mode 100644
--- /dev/null
+++ "b/train_ocr.ipynb"
@@ -0,0 +1,884 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "view-in-github"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XZTfJ2fy7Bxu"
+ },
+ "source": [
+ "## Fine-tune TrOCR on the IAM Handwriting Database\n",
+ "\n",
+ "In this notebook, we are going to fine-tune a pre-trained TrOCR model on the [IAM Handwriting Database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database), a collection of annotated images of handwritten text.\n",
+ "\n",
+ "We will do this using the new `VisionEncoderDecoderModel` class, which can be used to combine any image Transformer encoder (such as ViT or BEiT) with any text Transformer as decoder (such as BERT, RoBERTa or GPT-2). TrOCR is an instance of this: it has an encoder-decoder architecture, with the weights of the encoder initialized from a pre-trained BEiT and the weights of the decoder initialized from a pre-trained RoBERTa. The weights of the cross-attention layers were randomly initialized, after which the authors pre-trained the model further on millions of (partially synthetic) annotated images of handwritten text.\n",
+ "\n",
+ "This figure gives a good overview of the model (from the original paper):\n",
+ "\n",
+ "\n",
+ "\n",
+ "* TrOCR paper: https://arxiv.org/abs/2109.10282\n",
+ "* TrOCR documentation: https://huggingface.co/transformers/master/model_doc/trocr.html\n",
+ "\n",
+ "\n",
+ "Note that Patrick also wrote a very good [blog post](https://huggingface.co/blog/warm-starting-encoder-decoder) on warm-starting encoder-decoder models (which is what the TrOCR authors did). That post was very helpful when creating this notebook.\n",
+ "\n",
+ "We will fine-tune the model using the Seq2SeqTrainer, a subclass of the 🤗 Trainer that lets you compute generative metrics such as BLEU and ROUGE by doing generation (i.e. calling the `generate` method) inside the evaluation loop.\n",
+ "\n",
+ "\n",
+ "\n",
+ "## Set-up environment\n",
+ "\n",
+ "First, let's install the required libraries:\n",
+ "* Transformers (for the TrOCR model)\n",
+ "* Datasets & Jiwer (for the evaluation metric)\n",
+ "\n",
+ "We will not be using HuggingFace Datasets for data preprocessing in this notebook; instead, we will create a plain PyTorch Dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "pkSzlRJq68tH"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -q transformers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "id": "a8eZ6PWTHriw"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -q datasets jiwer"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "lsTaPrDR7My2"
+ },
+ "source": [
+ "## Prepare data\n",
+ "\n",
+ "We first download the data. Here, I'm just using the IAM test set, as this was released by the TrOCR authors in the unilm repository. It can be downloaded from [this page](https://github.com/microsoft/unilm/tree/master/trocr). \n",
+ "\n",
+ "Let's make a [regular PyTorch dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). We first create a Pandas dataframe with two columns: the file name of each image and the corresponding text."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 204
+ },
+ "id": "KkHqJw-W9Abl",
+ "outputId": "0db0a26a-4749-4b28-cd6f-fb355ad6aaf5"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<table border=\"1\" class=\"dataframe\">\n",
+ "  <thead>\n",
+ "    <tr style=\"text-align: right;\"><th></th><th>file_name</th><th>text</th></tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr><th>0</th><td>sai1 bian1 da3 wang3.png</td><td>sai1 bian1 da3 wang3</td></tr>\n",
+ "    <tr><th>0</th><td>que2 ya2 ba1.png</td><td>que2 ya2 ba1</td></tr>\n",
+ "    <tr><th>0</th><td>huang2 tong3.png</td><td>huang2 tong3</td></tr>\n",
+ "    <tr><th>0</th><td>do1.png</td><td>do1</td></tr>\n",
+ "    <tr><th>0</th><td>mao2 gou3.png</td><td>mao2 gou3</td></tr>\n",
+ "  </tbody>\n",
+ "</table>\n",