File size: 5,068 Bytes
2411029 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from PIL import Image
# -------------------------
# Text + Vocab utilities
# -------------------------
def normalize_text(s: str) -> str:
    """Conservatively clean a label string.

    Tabs become spaces, any run of whitespace collapses to a single
    space, and leading/trailing whitespace is stripped.  Deliberately
    minimal so ground-truth labels are not altered.
    """
    collapsed = re.sub(r"\s+", " ", s.replace("\t", " "))
    return collapsed.strip()
def build_vocab(texts: List[str]) -> Dict[str, int]:
    """Build a character-level vocabulary for CTC.

    Index 0 is reserved for the CTC blank symbol ("<BLANK>"); every
    character appearing in the normalized texts is assigned a stable
    index in sorted order starting at 1.
    """
    charset = {ch for text in texts for ch in normalize_text(text)}
    stoi: Dict[str, int] = {"<BLANK>": 0}
    for index, ch in enumerate(sorted(charset), start=1):
        stoi[ch] = index
    return stoi
def encode_text(text: str, stoi: Dict[str, int]) -> List[int]:
    """Encode a label string into vocabulary indices for CTC targets.

    Characters not present in *stoi* are silently dropped.  The former
    ``ch != "<BLANK>"`` guard was dead code — ``ch`` is always a single
    character and can never equal the multi-character blank key — so it
    has been removed without changing behavior.
    """
    text = normalize_text(text)
    return [stoi[ch] for ch in text if ch in stoi]
def save_vocab(stoi: Dict[str, int], path: str = "models/vocab.json") -> None:
    """Write the vocab mapping to *path* as pretty-printed UTF-8 JSON.

    Fix: the parent directory is created if missing, so the default
    "models/vocab.json" no longer raises FileNotFoundError on a fresh
    checkout where models/ does not yet exist.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as f:
        json.dump(stoi, f, ensure_ascii=False, indent=2)
# -------------------------
# Image preprocessing
# -------------------------
def preprocess_image(pil_img: Image.Image, img_height: int = 64) -> torch.Tensor:
    """Convert a PIL image into a normalized grayscale tensor [1, H, W].

    The image is resized to a fixed height with aspect ratio preserved
    (width scales proportionally, floor of 1 px), scaled to [0, 1],
    then mapped to [-1, 1].
    """
    gray = pil_img.convert("L")
    orig_w, orig_h = gray.size
    scaled_w = max(int(round(orig_w * (img_height / float(orig_h)))), 1)
    gray = gray.resize((scaled_w, img_height), Image.BILINEAR)
    # [H, W] float32 in [0, 1], then normalize to [-1, 1]
    pixels = np.array(gray).astype(np.float32) / 255.0
    pixels = (pixels - 0.5) / 0.5
    return torch.from_numpy(pixels).unsqueeze(0)  # [1, H, W]
# -------------------------
# Dataset + Collate
# -------------------------
class IAMLineTorchDataset(Dataset):
    """Torch Dataset wrapping the Teklia/IAM-line HuggingFace dataset.

    Each item is a triple of (grayscale image tensor [1, H, W],
    encoded label tensor of vocab indices, normalized label text).
    """

    def __init__(self, split: str, stoi: Dict[str, int], img_height: int = 64):
        self.ds = load_dataset("Teklia/IAM-line", split=split)
        self.stoi = stoi
        self.img_height = img_height

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx: int):
        record = self.ds[idx]
        raw_text = record["text"]
        image = preprocess_image(record["image"], img_height=self.img_height)  # [1,H,W]
        encoded = torch.tensor(encode_text(raw_text, self.stoi), dtype=torch.long)
        return image, encoded, normalize_text(raw_text)
@dataclass
class Batch:
    """One collated CTC training batch produced by ctc_collate_fn."""
    images: torch.Tensor       # [B, 1, H, Wmax], right-padded with zeros to the widest image
    image_widths: torch.Tensor # [B] original (pre-padding) widths
    targets: torch.Tensor      # [sum(T)] flattened label indices across the batch
    target_lengths: torch.Tensor  # [B] label length per sample, for CTCLoss
    texts: List[str]           # original normalized texts
def ctc_collate_fn(batch_items) -> Batch:
    """Collate (image, labels, text) triples into a Batch for CTCLoss.

    Images are right-padded with zeros to the widest image in the
    batch.  Labels are flattened into one 1-D tensor plus a per-sample
    length vector, the layout torch.nn.CTCLoss expects.

    Cleanup: the throwaway list comprehension inside torch.cat and the
    redundant ``target_lengths.sum() > 0`` guard were removed —
    concatenating empty 1-D long tensors already yields an empty long
    tensor, so the all-empty-labels case needs no special path.
    """
    imgs, labels_list, texts = zip(*batch_items)
    widths = torch.tensor([img.shape[-1] for img in imgs], dtype=torch.long)
    max_w = int(widths.max().item())
    b = len(imgs)
    c, h = imgs[0].shape[0], imgs[0].shape[1]
    padded = torch.zeros((b, c, h, max_w), dtype=torch.float32)
    for i, img in enumerate(imgs):
        padded[i, :, :, : img.shape[-1]] = img
    target_lengths = torch.tensor([len(lab) for lab in labels_list], dtype=torch.long)
    targets = torch.cat(labels_list, dim=0)
    return Batch(
        images=padded,
        image_widths=widths,
        targets=targets,
        target_lengths=target_lengths,
        texts=list(texts),
    )
def make_loaders(batch_size: int = 8, img_height: int = 64, num_workers: int = 0):
    """Build train/validation DataLoaders and the character vocab.

    The vocab is built from the texts of both splits (so validation-only
    characters are representable) and saved to models/vocab.json for
    later inference.  Returns (train_loader, val_loader, stoi).

    NOTE(review): each split is loaded twice — once here for vocab
    construction and once inside IAMLineTorchDataset; presumably the
    HF cache makes the second load cheap, but worth confirming.
    """
    splits = {
        name: load_dataset("Teklia/IAM-line", split=name)
        for name in ("train", "validation")
    }
    corpus = [row["text"] for ds in splits.values() for row in ds]
    stoi = build_vocab(corpus)
    # persist vocab for inference later
    save_vocab(stoi, "models/vocab.json")

    loaders = {}
    for name, shuffle in (("train", True), ("validation", False)):
        dataset = IAMLineTorchDataset(name, stoi, img_height=img_height)
        loaders[name] = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=ctc_collate_fn,
        )
    return loaders["train"], loaders["validation"], stoi
|