# lakshmi-charan's picture
# Upload 15 files
# 2411029 verified
import json
import os
import re
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from datasets import load_dataset
from PIL import Image
# -------------------------
# Text + Vocab utilities
# -------------------------
def normalize_text(s: str) -> str:
    """
    Conservatively clean a label string.

    Tabs become spaces, any run of whitespace collapses to one space,
    and leading/trailing whitespace is stripped — nothing else changes,
    so the labels themselves are not distorted.
    """
    collapsed = re.sub(r"\s+", " ", s.replace("\t", " "))
    return collapsed.strip()
def build_vocab(texts: List[str]) -> Dict[str, int]:
    """
    Build a character-level vocabulary for CTC.

    Index 0 is reserved for the CTC blank symbol ("<BLANK>"); every
    character appearing in the normalized texts receives a consecutive
    index starting at 1, in sorted character order.
    """
    charset = {ch for t in texts for ch in normalize_text(t)}
    stoi = {"<BLANK>": 0}
    stoi.update((ch, i) for i, ch in enumerate(sorted(charset), start=1))
    return stoi
def encode_text(text: str, stoi: Dict[str, int]) -> List[int]:
    """
    Encode a text line as a list of label indices for CTC training.

    Characters absent from the vocabulary are silently dropped.

    Note: the original code also tested ``ch != "<BLANK>"``, but ``ch``
    is always a single character and can never equal the multi-character
    "<BLANK>" key, so that check was dead and has been removed; the blank
    index (0) is still never emitted.
    """
    text = normalize_text(text)
    return [stoi[ch] for ch in text if ch in stoi]
def save_vocab(stoi: Dict[str, int], path: str = "models/vocab.json") -> None:
    """
    Serialize the vocabulary mapping to a JSON file.

    Creates the destination directory if it does not exist (the original
    code raised FileNotFoundError when ``models/`` was missing), then
    writes UTF-8 JSON with non-ASCII characters kept literal.
    """
    directory = os.path.dirname(path)
    if directory:
        # exist_ok avoids a race/error when the directory is already there.
        os.makedirs(directory, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(stoi, f, ensure_ascii=False, indent=2)
# -------------------------
# Image preprocessing
# -------------------------
def preprocess_image(pil_img: Image.Image, img_height: int = 64) -> torch.Tensor:
    """
    Convert a PIL image into a normalized grayscale tensor [1, H, W].

    The image is resized to ``img_height`` pixels tall with the aspect
    ratio preserved (width scales accordingly, minimum 1 px), scaled to
    [0, 1], then normalized to [-1, 1].
    """
    gray = pil_img.convert("L")
    orig_w, orig_h = gray.size
    scaled_w = max(int(round(orig_w * (img_height / float(orig_h)))), 1)
    gray = gray.resize((scaled_w, img_height), Image.BILINEAR)

    pixels = np.asarray(gray, dtype=np.float32) / 255.0  # [H, W] in 0..1
    pixels = (pixels - 0.5) / 0.5  # standard grayscale normalization -> [-1, 1]
    return torch.from_numpy(pixels).unsqueeze(0)  # [1, H, W]
# -------------------------
# Dataset + Collate
# -------------------------
class IAMLineTorchDataset(Dataset):
    """
    Torch Dataset wrapper over the Hugging Face ``Teklia/IAM-line`` split.

    Each item is a triple:
        (image tensor [1, H, W], long tensor of label indices, normalized text).
    """

    def __init__(self, split: str, stoi: Dict[str, int], img_height: int = 64):
        self.ds = load_dataset("Teklia/IAM-line", split=split)
        self.stoi = stoi
        self.img_height = img_height

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx: int):
        record = self.ds[idx]
        raw_text = record["text"]
        image = preprocess_image(record["image"], img_height=self.img_height)  # [1,H,W]
        encoded = torch.tensor(encode_text(raw_text, self.stoi), dtype=torch.long)
        return image, encoded, normalize_text(raw_text)
@dataclass
class Batch:
    """One collated CTC training batch, as produced by ctc_collate_fn."""
    images: torch.Tensor        # [B, 1, H, Wmax] grayscale images, zero-padded on the right
    image_widths: torch.Tensor  # [B] original (pre-padding) widths, long
    targets: torch.Tensor       # [sum(T)] flattened label indices for torch.nn.CTCLoss
    target_lengths: torch.Tensor  # [B] per-sample label lengths, long
    texts: List[str]            # original normalized texts (for decoding/metrics)
def ctc_collate_fn(batch_items) -> Batch:
    """
    Collate (image, labels, text) triples into a CTC-ready Batch.

    Images are zero-padded on the right to the widest image in the batch.
    Labels are concatenated into a single flat 1-D long tensor paired with
    per-sample lengths, the layout torch.nn.CTCLoss expects.

    Changes vs. original: ``torch.cat(labels_list)`` replaces the redundant
    ``torch.cat([l for l in labels_list])`` and the ambiguous name ``l``.
    """
    imgs, labels_list, texts = zip(*batch_items)

    widths = torch.tensor([img.shape[-1] for img in imgs], dtype=torch.long)
    max_w = int(widths.max().item())
    b = len(imgs)
    c, h = imgs[0].shape[0], imgs[0].shape[1]

    padded = torch.zeros((b, c, h, max_w), dtype=torch.float32)
    for i, img in enumerate(imgs):
        padded[i, :, :, : img.shape[-1]] = img

    target_lengths = torch.tensor([len(lbl) for lbl in labels_list], dtype=torch.long)
    # Explicit guard keeps an all-empty batch a well-typed empty long tensor.
    if target_lengths.sum() > 0:
        targets = torch.cat(labels_list, dim=0)
    else:
        targets = torch.tensor([], dtype=torch.long)

    return Batch(
        images=padded,
        image_widths=widths,
        targets=targets,
        target_lengths=target_lengths,
        texts=list(texts),
    )
def make_loaders(batch_size: int = 8, img_height: int = 64, num_workers: int = 0):
    """
    Build train/validation DataLoaders over IAM lines plus the shared vocab.

    The char vocabulary is built from BOTH splits so validation never hits
    out-of-vocabulary characters, and it is saved to models/vocab.json for
    later inference.

    Returns:
        (train_loader, val_loader, stoi)
    """
    train_ds = load_dataset("Teklia/IAM-line", split="train")
    val_ds = load_dataset("Teklia/IAM-line", split="validation")
    # Column access (ds["text"]) reads only the text column; iterating full
    # rows (the original code) decodes every image just to get its string.
    all_texts = list(train_ds["text"]) + list(val_ds["text"])
    stoi = build_vocab(all_texts)

    # Persist the vocab so inference can rebuild the same index mapping.
    save_vocab(stoi, "models/vocab.json")

    train = IAMLineTorchDataset("train", stoi, img_height=img_height)
    val = IAMLineTorchDataset("validation", stoi, img_height=img_height)

    train_loader = DataLoader(
        train, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, collate_fn=ctc_collate_fn,
    )
    val_loader = DataLoader(
        val, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, collate_fn=ctc_collate_fn,
    )
    return train_loader, val_loader, stoi