|
|
import json
import os
import re
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from datasets import load_dataset
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def normalize_text(s: str) -> str:
    """Conservatively clean a transcription string.

    Replaces tabs with spaces, collapses any run of whitespace into a
    single space, and trims the ends. Deliberately minimal so the
    ground-truth labels are not distorted.
    """
    without_tabs = s.replace("\t", " ")
    collapsed = re.sub(r"\s+", " ", without_tabs)
    return collapsed.strip()
|
|
|
|
|
|
|
|
|
def build_vocab(texts: List[str]) -> Dict[str, int]:
    """Build a character-level vocabulary for CTC training.

    Index 0 is reserved for the CTC blank token ("<BLANK>"); every
    character appearing in the normalized texts gets an index 1..N in
    sorted order, so the mapping is deterministic across runs.

    Args:
        texts: Raw label strings; each is normalized before indexing.

    Returns:
        Mapping from token to integer id, with "<BLANK>" -> 0.
    """
    charset = set()
    for t in texts:
        # set.update over the string replaces the per-character inner loop.
        charset.update(normalize_text(t))

    stoi = {"<BLANK>": 0}
    # sorted() accepts any iterable; no intermediate list() needed.
    for i, ch in enumerate(sorted(charset), start=1):
        stoi[ch] = i
    return stoi
|
|
|
|
|
|
|
|
|
def encode_text(text: str, stoi: Dict[str, int]) -> List[int]:
    """Encode a label string as a list of vocab ids for CTC targets.

    Characters missing from the vocab are silently skipped, since they
    cannot be represented. Note that iterating a string yields single
    characters, so the multi-character "<BLANK>" key can never match —
    the explicit `ch != "<BLANK>"` exclusion in the original was dead
    code and has been removed.

    Args:
        text: Raw label; normalized before encoding.
        stoi: Token-to-id mapping produced by build_vocab.

    Returns:
        Integer ids, one per encodable character.
    """
    text = normalize_text(text)
    return [stoi[ch] for ch in text if ch in stoi]
|
|
|
|
|
|
|
|
|
def save_vocab(stoi: Dict[str, int], path: str = "models/vocab.json") -> None:
    """Persist the vocabulary as UTF-8 JSON.

    Creates the parent directory if it does not exist — the default
    path writes into "models/", which is absent on a fresh checkout
    and previously caused a FileNotFoundError.

    Args:
        stoi: Token-to-id mapping.
        path: Destination file path.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(stoi, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess_image(pil_img: Image.Image, img_height: int = 64) -> torch.Tensor:
    """Convert a PIL image to a normalized grayscale tensor [1, H, W].

    The image is resized to `img_height` while preserving aspect ratio
    (width scales proportionally, minimum 1 px), scaled to [0, 1], then
    normalized to [-1, 1].
    """
    gray = pil_img.convert("L")
    orig_w, orig_h = gray.size

    # Width follows the height rescale factor; clamp to at least 1 px.
    scale = img_height / float(orig_h)
    target_w = max(int(round(orig_w * scale)), 1)
    resized = gray.resize((target_w, img_height), Image.BILINEAR)

    pixels = np.array(resized).astype(np.float32) / 255.0
    normalized = (pixels - 0.5) / 0.5  # map [0, 1] -> [-1, 1]

    return torch.from_numpy(normalized).unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IAMLineTorchDataset(Dataset):
    """Torch-style view of a Teklia/IAM-line split for CTC training.

    Each item is a (image_tensor, label_ids, normalized_text) triple
    suitable for batching with ctc_collate_fn.
    """

    def __init__(self, split: str, stoi: Dict[str, int], img_height: int = 64):
        self.ds = load_dataset("Teklia/IAM-line", split=split)
        self.stoi = stoi
        self.img_height = img_height

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx: int):
        sample = self.ds[idx]
        raw_text = sample["text"]
        image = preprocess_image(sample["image"], img_height=self.img_height)
        # encode_text normalizes internally, so the raw text is passed as-is.
        label_ids = torch.tensor(encode_text(raw_text, self.stoi), dtype=torch.long)
        return image, label_ids, normalize_text(raw_text)
|
|
|
|
|
|
|
|
|
@dataclass
class Batch:
    """A collated CTC training batch (see ctc_collate_fn)."""

    # [B, C, H, max_W] float images, right-padded with zeros to the batch max width.
    images: torch.Tensor
    # [B] original (pre-padding) pixel widths of each image.
    image_widths: torch.Tensor
    # 1-D concatenation of all label id sequences (layout torch.nn.CTCLoss expects).
    targets: torch.Tensor
    # [B] label-sequence length for each sample.
    target_lengths: torch.Tensor
    # Normalized ground-truth strings, kept for decoding/metrics.
    texts: List[str]
|
|
|
|
|
|
|
|
|
def ctc_collate_fn(batch_items) -> Batch:
    """Collate (image, label_ids, text) triples into a CTC-ready Batch.

    Images are right-padded with zeros to the widest image in the
    batch. Labels are flattened into a single 1-D tensor plus
    per-sample lengths — the layout torch.nn.CTCLoss expects.

    Args:
        batch_items: Sequence of (image [C, H, W], labels [L], text)
            triples, e.g. from IAMLineTorchDataset. All images must
            share the same channel count and height.

    Returns:
        Batch with padded images, original widths, flattened targets,
        per-sample target lengths, and the normalized texts.

    Raises:
        ValueError: if batch_items is empty (zip unpacking fails).
    """
    imgs, labels_list, texts = zip(*batch_items)

    widths = torch.tensor([img.shape[-1] for img in imgs], dtype=torch.long)
    max_w = int(widths.max().item())
    c, h = imgs[0].shape[0], imgs[0].shape[1]

    padded = torch.zeros((len(imgs), c, h, max_w), dtype=torch.float32)
    for i, img in enumerate(imgs):
        padded[i, :, :, : img.shape[-1]] = img

    # `lbl` instead of the ambiguous `l` (flake8 E741).
    target_lengths = torch.tensor([len(lbl) for lbl in labels_list], dtype=torch.long)
    # torch.cat handles per-sample empty tensors, so the previous
    # `target_lengths.sum() > 0` special case was redundant: the list
    # itself is always non-empty when zip succeeded above.
    targets = torch.cat(labels_list)

    return Batch(
        images=padded,
        image_widths=widths,
        targets=targets,
        target_lengths=target_lengths,
        texts=list(texts),
    )
|
|
|
|
|
|
|
|
|
def make_loaders(batch_size: int = 8, img_height: int = 64, num_workers: int = 0):
    """Create train/validation DataLoaders for IAM-line CTC training.

    Builds a character vocabulary from the train + validation
    transcriptions, saves it to models/vocab.json, then wraps each
    split in a DataLoader that uses ctc_collate_fn.

    Returns:
        (train_loader, val_loader, stoi)
    """
    all_texts: List[str] = []
    for split_name in ("train", "validation"):
        split = load_dataset("Teklia/IAM-line", split=split_name)
        all_texts.extend(x["text"] for x in split)

    stoi = build_vocab(all_texts)
    save_vocab(stoi, "models/vocab.json")

    loaders = {}
    for split_name, shuffle in (("train", True), ("validation", False)):
        ds = IAMLineTorchDataset(split_name, stoi, img_height=img_height)
        loaders[split_name] = DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=ctc_collate_fn,
        )

    return loaders["train"], loaders["validation"], stoi
|
|
|
|