"""Data pipeline for IAM handwriting-line recognition with CTC.

Provides text normalization, char-level vocabulary construction, grayscale
image preprocessing, a torch ``Dataset`` over the ``Teklia/IAM-line``
HuggingFace dataset, and a CTC-friendly collate function that pads images
and flattens targets for ``nn.CTCLoss``.
"""

import json
import os
import re
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from torch.utils.data import DataLoader, Dataset

# -------------------------
# Text + Vocab utilities
# -------------------------


def normalize_text(s: str) -> str:
    """Conservative cleanup so labels are not ruined.

    Replaces tabs with spaces, collapses whitespace runs to a single
    space, and strips both ends.
    """
    s = s.replace("\t", " ")
    s = re.sub(r"\s+", " ", s)
    return s.strip()


def build_vocab(texts: List[str]) -> Dict[str, int]:
    """Build a char-level vocab for CTC from normalized texts.

    Index 0 is reserved for the CTC blank, keyed by the empty string.
    Remaining characters are assigned indices 1..N in sorted order so the
    mapping is deterministic across runs.
    """
    charset = set()
    for t in texts:
        charset.update(normalize_text(t))
    stoi = {"": 0}  # CTC blank
    for i, ch in enumerate(sorted(charset), start=1):
        stoi[ch] = i
    return stoi


def encode_text(text: str, stoi: Dict[str, int]) -> List[int]:
    """Encode normalized text as a list of label indices.

    Characters missing from the vocab are silently dropped. The
    ``ch != ""`` guard keeps the blank entry (key ``""``) from ever being
    emitted; a single character can never equal the empty string, so this
    is purely defensive.
    """
    text = normalize_text(text)
    return [stoi[ch] for ch in text if ch in stoi and ch != ""]


def save_vocab(stoi: Dict[str, int], path: str = "models/vocab.json") -> None:
    """Write the vocab as UTF-8 JSON to *path*.

    Creates the parent directory if it does not exist — previously this
    raised FileNotFoundError on a fresh checkout without ``models/``.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(stoi, f, ensure_ascii=False, indent=2)


# -------------------------
# Image preprocessing
# -------------------------


def preprocess_image(pil_img: Image.Image, img_height: int = 64) -> torch.Tensor:
    """Convert a PIL image to a normalized grayscale tensor ``[1, H, W]``.

    The aspect ratio is preserved: the image is resized to
    ``img_height`` and the width scales accordingly (minimum width 1).
    Pixel values are mapped to ``[-1, 1]``.
    """
    img = pil_img.convert("L")  # grayscale
    w, h = img.size
    new_h = img_height
    new_w = max(int(round(w * (new_h / float(h)))), 1)
    img = img.resize((new_w, new_h), Image.BILINEAR)
    arr = np.array(img).astype(np.float32) / 255.0  # [H, W] in 0..1
    arr = (arr - 0.5) / 0.5  # -> [-1, 1]
    return torch.from_numpy(arr).unsqueeze(0)  # [1, H, W]


# -------------------------
# Dataset + Collate
# -------------------------


class IAMLineTorchDataset(Dataset):
    """Torch view over the ``Teklia/IAM-line`` HF dataset for one split."""

    def __init__(self, split: str, stoi: Dict[str, int], img_height: int = 64):
        self.ds = load_dataset("Teklia/IAM-line", split=split)
        self.stoi = stoi
        self.img_height = img_height

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx: int):
        """Return ``(image [1,H,W], label indices, normalized text)``."""
        sample = self.ds[idx]
        img = preprocess_image(sample["image"], img_height=self.img_height)
        text = sample["text"]
        labels = torch.tensor(encode_text(text, self.stoi), dtype=torch.long)
        return img, labels, normalize_text(text)


@dataclass
class Batch:
    images: torch.Tensor        # [B, 1, H, Wmax]
    image_widths: torch.Tensor  # [B] unpadded widths
    targets: torch.Tensor       # [sum(T)] flattened labels for CTCLoss
    target_lengths: torch.Tensor  # [B]
    texts: List[str]            # original normalized texts


def ctc_collate_fn(batch_items) -> Batch:
    """Collate ``(img, labels, text)`` triples into a :class:`Batch`.

    Images are right-padded with zeros to the batch's max width, and
    targets are flattened with per-item lengths, as ``nn.CTCLoss``
    expects.

    NOTE(review): after the [-1, 1] normalization, a pad value of 0 is
    mid-gray rather than the (white ≈ 1.0) page background — confirm this
    is the intended padding value.
    """
    imgs, labels_list, texts = zip(*batch_items)
    widths = torch.tensor([img.shape[-1] for img in imgs], dtype=torch.long)
    max_w = int(widths.max().item())

    b = len(imgs)
    c, h = imgs[0].shape[0], imgs[0].shape[1]
    padded = torch.zeros((b, c, h, max_w), dtype=torch.float32)
    for i, img in enumerate(imgs):
        padded[i, :, :, : img.shape[-1]] = img

    target_lengths = torch.tensor([len(l) for l in labels_list], dtype=torch.long)
    if target_lengths.sum() > 0:
        targets = torch.cat(list(labels_list), dim=0)
    else:
        # Degenerate batch where every label is empty.
        targets = torch.tensor([], dtype=torch.long)

    return Batch(
        images=padded,
        image_widths=widths,
        targets=targets,
        target_lengths=target_lengths,
        texts=list(texts),
    )


def make_loaders(batch_size: int = 8, img_height: int = 64, num_workers: int = 0):
    """Build train/validation loaders and the shared vocab.

    The vocab is built from train + validation texts and saved to
    ``models/vocab.json`` for later inference. Returns
    ``(train_loader, val_loader, stoi)``.
    """
    train_ds = load_dataset("Teklia/IAM-line", split="train")
    val_ds = load_dataset("Teklia/IAM-line", split="validation")
    all_texts = [x["text"] for x in train_ds] + [x["text"] for x in val_ds]
    stoi = build_vocab(all_texts)
    save_vocab(stoi, "models/vocab.json")

    train = IAMLineTorchDataset("train", stoi, img_height=img_height)
    val = IAMLineTorchDataset("validation", stoi, img_height=img_height)

    train_loader = DataLoader(
        train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=ctc_collate_fn,
    )
    val_loader = DataLoader(
        val,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=ctc_collate_fn,
    )
    return train_loader, val_loader, stoi