File size: 5,068 Bytes
2411029
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import json
import os
import re
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from PIL import Image


# -------------------------
# Text + Vocab utilities
# -------------------------

def normalize_text(s: str) -> str:
    """Conservatively clean a label string.

    Tabs become spaces, any run of whitespace collapses to a single
    space, and leading/trailing whitespace is stripped.  Kept minimal
    on purpose so ground-truth labels are not altered beyond spacing.
    """
    collapsed = re.sub(r"\s+", " ", s.replace("\t", " "))
    return collapsed.strip()


def build_vocab(texts: List[str]) -> Dict[str, int]:
    """Build a character-level vocabulary for CTC training.

    Index 0 is reserved for the CTC blank token ("<BLANK>"); every
    character appearing in the normalized texts gets an index >= 1,
    assigned in sorted order so the mapping is deterministic across runs.

    Args:
        texts: raw label strings (normalized internally).

    Returns:
        Mapping from token to integer id, including "<BLANK>" -> 0.
    """
    charset = set()
    for t in texts:
        # set.update adds every character of the string in one call,
        # replacing the original char-by-char inner loop.
        charset.update(normalize_text(t))

    # sorted() accepts any iterable; the original list(charset) wrapper
    # was redundant.
    stoi = {"<BLANK>": 0}
    for i, ch in enumerate(sorted(charset), start=1):
        stoi[ch] = i
    return stoi


def encode_text(text: str, stoi: Dict[str, int]) -> List[int]:
    """Map a label string to a list of vocab ids for CTC targets.

    Characters not present in the vocab are silently dropped.  The
    original guard `ch != "<BLANK>"` was dead code: "<BLANK>" is a
    multi-character key, so no single character of the input could ever
    equal it — the blank id (0) is unreachable here by construction.
    """
    text = normalize_text(text)
    return [stoi[ch] for ch in text if ch in stoi]


def save_vocab(stoi: Dict[str, int], path: str = "models/vocab.json") -> None:
    """Write the vocab mapping to *path* as pretty-printed UTF-8 JSON.

    Creates the parent directory if it does not exist — the original
    raised FileNotFoundError on a fresh checkout where "models/" was
    missing.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(stoi, f, ensure_ascii=False, indent=2)


# -------------------------
# Image preprocessing
# -------------------------

def preprocess_image(pil_img: Image.Image, img_height: int = 64) -> torch.Tensor:
    """Convert a PIL image to a normalized grayscale tensor [1, H, W].

    The image is resized to a fixed height with the width scaled to
    preserve aspect ratio (minimum width 1), pixel values are scaled to
    [0, 1], then normalized to [-1, 1].
    """
    gray = pil_img.convert("L")
    orig_w, orig_h = gray.size
    scale = img_height / float(orig_h)
    target_w = max(int(round(orig_w * scale)), 1)

    resized = gray.resize((target_w, img_height), Image.BILINEAR)

    pixels = np.array(resized).astype(np.float32) / 255.0  # [H, W] in 0..1
    pixels = (pixels - 0.5) / 0.5  # shift/scale to [-1, 1]

    return torch.from_numpy(pixels).unsqueeze(0)  # prepend channel dim


# -------------------------
# Dataset + Collate
# -------------------------

class IAMLineTorchDataset(Dataset):
    """Torch Dataset wrapping a HuggingFace "Teklia/IAM-line" split.

    Each item is a (image_tensor, label_ids, normalized_text) triple
    suitable for CTC training via `ctc_collate_fn`.
    """

    def __init__(self, split: str, stoi: Dict[str, int], img_height: int = 64):
        self.ds = load_dataset("Teklia/IAM-line", split=split)
        self.stoi = stoi
        self.img_height = img_height

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx: int):
        record = self.ds[idx]
        image = preprocess_image(record["image"], img_height=self.img_height)  # [1,H,W]
        raw_text = record["text"]
        target = torch.tensor(encode_text(raw_text, self.stoi), dtype=torch.long)
        return image, target, normalize_text(raw_text)


@dataclass
class Batch:
    """One collated CTC training batch (produced by `ctc_collate_fn`).

    Images are right-padded with zeros to the widest image in the
    batch; all label sequences are concatenated into a single 1-D
    tensor, recoverable per sample via `target_lengths` — the layout
    `torch.nn.CTCLoss` expects.
    """
    images: torch.Tensor          # [B, 1, H, Wmax]
    image_widths: torch.Tensor    # [B]
    targets: torch.Tensor         # [sum(T)]
    target_lengths: torch.Tensor  # [B]
    texts: List[str]              # original normalized texts


def ctc_collate_fn(batch_items) -> Batch:
    """Collate (image, labels, text) triples into a CTC-ready Batch.

    Images are zero-padded on the right up to the widest image in the
    batch; label tensors are flattened into one 1-D tensor plus
    per-sample lengths, as `torch.nn.CTCLoss` expects.
    """
    imgs, labels_list, texts = zip(*batch_items)

    widths = torch.tensor([img.shape[-1] for img in imgs], dtype=torch.long)
    max_w = int(widths.max().item())
    b = len(imgs)
    c, h = imgs[0].shape[0], imgs[0].shape[1]

    # Right-pad every image with zeros to the batch's max width.
    padded = torch.zeros((b, c, h, max_w), dtype=torch.float32)
    for i, img in enumerate(imgs):
        w = img.shape[-1]
        padded[i, :, :, :w] = img

    target_lengths = torch.tensor([len(l) for l in labels_list], dtype=torch.long)
    # torch.cat concatenates zero-length tensors fine (yielding an empty
    # long tensor), so the original `sum > 0` ternary and the redundant
    # `[l for l in labels_list]` comprehension were unnecessary.
    targets = torch.cat(labels_list)

    return Batch(
        images=padded,
        image_widths=widths,
        targets=targets,
        target_lengths=target_lengths,
        texts=list(texts),
    )


def make_loaders(batch_size: int = 8, img_height: int = 64, num_workers: int = 0):
    """Create train/validation DataLoaders for IAM line recognition.

    Builds the character vocab from the train + validation label texts,
    saves it to "models/vocab.json" for later inference, and wraps both
    splits in `IAMLineTorchDataset` with the CTC collate function.

    Returns:
        (train_loader, val_loader, stoi)
    """
    # Build vocab from train + validation so no charset surprises at eval.
    train_ds = load_dataset("Teklia/IAM-line", split="train")
    val_ds = load_dataset("Teklia/IAM-line", split="validation")

    all_texts = [x["text"] for x in train_ds] + [x["text"] for x in val_ds]
    stoi = build_vocab(all_texts)

    # Ensure the output directory exists before writing — open() would
    # otherwise raise FileNotFoundError on a fresh checkout.
    os.makedirs("models", exist_ok=True)
    save_vocab(stoi, "models/vocab.json")

    train = IAMLineTorchDataset("train", stoi, img_height=img_height)
    val = IAMLineTorchDataset("validation", stoi, img_height=img_height)

    train_loader = DataLoader(
        train, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, collate_fn=ctc_collate_fn
    )
    val_loader = DataLoader(
        val, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, collate_fn=ctc_collate_fn
    )
    return train_loader, val_loader, stoi