# Source page header (Hugging Face): ocr / IAM_train.py — uploaded by hanz245, commit 7111e1a ("set up")
"""
IAM_train.py
============
Fine-tune the CRNN model using the IAM Handwriting Word Database.
Builds on top of EMNIST-trained model (best_model_emnist.pth).
FIXES vs old version:
- IMG_WIDTH 400 -> 512 (must match pipeline)
- Added log_softmax before CTCLoss (was missing — caused catastrophic forgetting)
- Phase 1: CNN FROZEN β€” only RNN+FC trained
- Phase 2: Full model at very low LR
- Loads from best_model_emnist.pth, falls back to best_model.pth
- Uses get_crnn_model() with correct architecture from checkpoint config
DATASET:
Download from: https://www.kaggle.com/datasets/nibinv23/iam-handwriting-word-database
Expected structure:
data/IAM/iam_words/
words/ <- word image folders (a01, a02, ...)
words.txt <- annotation file
USAGE:
python IAM_train.py --prepare # convert IAM -> annotation JSON
python IAM_train.py --train # fine-tune model
python IAM_train.py --prepare --train # do both
"""
import os
import sys
import json
import argparse
import random
from pathlib import Path
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
sys.path.append('.')
from crnn_model import get_crnn_model
from dataset import CivilRegistryDataset, collate_fn
# ─────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────
# Raw IAM dataset layout on disk (see module docstring for download link).
IAM_ROOT = "data/IAM/iam_words"
IAM_WORDS_TXT = f"{IAM_ROOT}/words.txt"  # per-word annotation file shipped with IAM
IAM_WORDS_DIR = f"{IAM_ROOT}/words"      # word image folders (a01, a02, ...)
# Annotation JSONs produced/consumed by this script.
TRAIN_ANN = "data/iam_train_annotations.json"
IAM_VAL_ANN = "data/iam_val_annotations.json" # written by --prepare (IAM word images)
SYNTH_VAL_ANN = "data/val_annotations.json" # real civil registry val set — never overwritten
# Output directories for the resized IAM word images.
TRAIN_IMG_DIR = "data/train/iam"
VAL_IMG_DIR = "data/val/iam"
# Model input geometry — must match the inference pipeline exactly.
IMG_HEIGHT = 64
IMG_WIDTH = 512 # FIXED: was 400 — must match pipeline
BATCH_SIZE = 32
VAL_SPLIT = 0.1      # fraction of IAM entries held out for validation
MAX_SAMPLES = 50000  # cap on IAM word images used (falsy value disables the cap)
# Load from EMNIST checkpoint, fall back to synthetic if not found
CHECKPOINT_IN = "checkpoints/best_model_emnist.pth"
CHECKPOINT_IN2 = "checkpoints/best_model.pth" # fallback
CHECKPOINT_OUT = "checkpoints/best_model_iam.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ─────────────────────────────────────────────
# STEP 1 β€” PREPARE
# ─────────────────────────────────────────────
def prepare_iam():
    """Convert the raw IAM word database into resized JPEGs plus annotation JSONs.

    Reads words.txt, keeps entries whose segmentation result is 'ok' and whose
    transcription is 1-32 characters, resizes each word image to
    IMG_WIDTH x IMG_HEIGHT, and writes train/val annotation files
    (TRAIN_ANN / IAM_VAL_ANN). Exits the process with status 1 when the
    dataset is missing.
    """
    from PIL import Image  # local import: Pillow is only needed for --prepare

    print("\n" + "=" * 50)
    # FIXED: repaired mojibake ("β€”" -> "—") in the banner.
    print("STEP 1 — Preparing IAM dataset")
    print("=" * 50)
    if not os.path.exists(IAM_WORDS_TXT):
        print(f"ERROR: {IAM_WORDS_TXT} not found!")
        print("Download from: https://www.kaggle.com/datasets/nibinv23/iam-handwriting-word-database")
        print("Expected structure:")
        print(" data/IAM/iam_words/words.txt")
        print(" data/IAM/iam_words/words/")
        sys.exit(1)
    os.makedirs(TRAIN_IMG_DIR, exist_ok=True)
    os.makedirs(VAL_IMG_DIR, exist_ok=True)

    entries = []
    print(f" Reading {IAM_WORDS_TXT} ...")
    with open(IAM_WORDS_TXT, "r") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):  # blank rows / header comments
                continue
            parts = line.split(" ")
            if len(parts) < 9:                    # malformed row
                continue
            word_id = parts[0]
            seg_result = parts[1]
            text = parts[-1]                      # transcription is the last field
            if seg_result != "ok":                # drop badly-segmented words
                continue
            if len(text) < 1 or len(text) > 32:   # drop empty / overly long labels
                continue
            # A word_id like "a01-000u-00-00" lives at words/a01/a01-000u/<id>.png
            parts_id = word_id.split("-")
            img_path = os.path.join(
                IAM_WORDS_DIR,
                parts_id[0],
                f"{parts_id[0]}-{parts_id[1]}",
                f"{word_id}.png"
            )
            if not os.path.exists(img_path):
                continue
            entries.append((img_path, text))
    print(f" Found {len(entries)} valid word entries")

    if MAX_SAMPLES and len(entries) > MAX_SAMPLES:
        random.shuffle(entries)
        entries = entries[:MAX_SAMPLES]
        print(f" Limiting to {MAX_SAMPLES} samples")
    random.shuffle(entries)
    split_idx = int(len(entries) * (1 - VAL_SPLIT))
    train_entries = entries[:split_idx]
    val_entries = entries[split_idx:]
    print(f" Train: {len(train_entries)} | Val: {len(val_entries)}")
    print(" Copying and resizing images...")

    def process_entries(entry_list, out_dir, prefix):
        """Resize every entry to IMG_WIDTH x IMG_HEIGHT JPEG; return its annotations."""
        annotations = []
        for i, (src_path, text) in enumerate(entry_list):
            # FIXED: progress is reported before the conversion attempt, so
            # entries that fail (and hit the `continue` below) no longer
            # suppress the progress line.
            if i % 5000 == 0:
                print(f" {i}/{len(entry_list)} processed...")
            try:
                img = Image.open(src_path).convert("RGB")
                img = img.resize((IMG_WIDTH, IMG_HEIGHT))  # FIXED: 512x64
                fname = f"iam_{prefix}_{i:06d}.jpg"
                out_path = os.path.join(out_dir, fname)
                img.save(out_path, quality=90)
                annotations.append({"image_path": f"iam/{fname}", "text": text})
            except Exception:
                continue  # best-effort: skip unreadable/corrupt images
        return annotations

    train_ann = process_entries(train_entries, TRAIN_IMG_DIR, "train")
    val_ann = process_entries(val_entries, VAL_IMG_DIR, "val")
    with open(TRAIN_ANN, "w") as f:
        json.dump(train_ann, f, indent=2)
    with open(IAM_VAL_ANN, "w") as f:
        json.dump(val_ann, f, indent=2)
    print(f"\n Train annotations -> {TRAIN_ANN} ({len(train_ann)} entries)")
    print(f" Val annotations -> {IAM_VAL_ANN} ({len(val_ann)} entries)")
    print("\n Done! Now run: python IAM_train.py --train")
# ─────────────────────────────────────────────
# STEP 2 β€” TRAIN
# ─────────────────────────────────────────────
def train_iam():
    """Fine-tune the CRNN on IAM word images mixed with synthetic registry data.

    Two-phase schedule: phase 1 trains RNN+FC with the CNN frozen; phase 2
    unfreezes the full model at a very low learning rate. The best checkpoint
    (by validation loss) is written to CHECKPOINT_OUT. Exits with status 1
    when required annotation files or input checkpoints are missing.
    """
    print("\n" + "=" * 55)
    print("STEP 2 β€” Fine-tuning CRNN with IAM dataset")
    print("=" * 55)
    print(f" Device : {DEVICE}")
    # Both the IAM training annotations and the synthetic validation
    # annotations must exist before training can start.
    for ann_file in [TRAIN_ANN, SYNTH_VAL_ANN]:
        if not os.path.exists(ann_file):
            print(f"ERROR: {ann_file} not found! Run --prepare first.")
            sys.exit(1)
    train_dataset = CivilRegistryDataset(
        data_dir="data/train", annotations_file=TRAIN_ANN,
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH, augment=True
    )
    # FIXED: mix synthetic data in so the model never forgets Filipino multi-word sequences
    synth_dataset = CivilRegistryDataset(
        data_dir="data/train", annotations_file="data/train_annotations.json",
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH, augment=True
    )
    mixed_train = ConcatDataset([train_dataset, synth_dataset])
    # Validation uses the civil-registry set only (never the IAM split), so the
    # saved "best" model is the one best on the target domain.
    val_dataset = CivilRegistryDataset(
        data_dir="data/val", annotations_file=SYNTH_VAL_ANN,
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH, augment=False
    )
    print(f" IAM train : {len(train_dataset)}")
    print(f" Synthetic train: {len(synth_dataset)}")
    print(f" Mixed train : {len(mixed_train)}")
    print(f" Val : {len(val_dataset)}")
    train_loader = DataLoader(mixed_train, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=0, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                            shuffle=False, num_workers=0, collate_fn=collate_fn)
    # ── Load checkpoint (EMNIST preferred, synthetic fallback) ──
    ckpt_path = CHECKPOINT_IN if os.path.exists(CHECKPOINT_IN) else CHECKPOINT_IN2
    if not os.path.exists(ckpt_path):
        print(f"ERROR: No checkpoint found at {CHECKPOINT_IN} or {CHECKPOINT_IN2}")
        print("Run: python train.py then python train_with_emnist.py")
        sys.exit(1)
    print(f" Loading: {ckpt_path}")
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)
    config = ckpt.get('config', {})
    # Rebuild the model with the exact architecture stored in the checkpoint's
    # config so the saved weights line up.
    model = get_crnn_model(
        model_type = config.get('model_type', 'standard'),
        img_height = config.get('img_height', 64),
        num_chars = train_dataset.num_chars,
        hidden_size = config.get('hidden_size', 128),
        num_lstm_layers = config.get('num_lstm_layers', 1),
    ).to(DEVICE)
    # strict=False: layers absent from the checkpoint keep their fresh
    # (random) initialization instead of raising.
    missing, _ = model.load_state_dict(ckpt['model_state_dict'], strict=False)
    if missing:
        print(f" Note: {len(missing)} layers re-initialized")
    print(f" Loaded epoch {ckpt.get('epoch', 'N/A')} "
          f"val_loss={ckpt.get('val_loss', ckpt.get('val_cer', 0)):.4f}")
    # blank=0 must match the dataset's char index mapping; zero_infinity guards
    # against inf losses from degenerate alignments.
    criterion = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
    os.makedirs("checkpoints", exist_ok=True)

    def run_epoch(loader, training, optimizer=None):
        """Run one pass over `loader`; return mean CTC loss (0.0 if no batches counted)."""
        model.train() if training else model.eval()
        total, n = 0, 0
        ctx = torch.enable_grad() if training else torch.no_grad()
        with ctx:
            for images, targets, target_lengths, _ in loader:
                images = images.to(DEVICE)
                batch_size = images.size(0)
                if training:
                    optimizer.zero_grad()
                # CRITICAL: log_softmax before CTCLoss
                outputs = F.log_softmax(model(images), dim=2)
                seq_len = outputs.size(0)
                input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long)
                loss = criterion(outputs, targets, input_lengths, target_lengths)
                # Skip NaN/inf batches entirely so they neither update the
                # weights nor skew the running mean.
                if not torch.isnan(loss) and not torch.isinf(loss):
                    if training:
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
                        optimizer.step()
                    total += loss.item()
                    n += 1
        return total / max(n, 1)

    def run_phase(num, epochs, lr, freeze_cnn, patience):
        """Train for up to `epochs` epochs with early stopping; return best val loss."""
        print(f"\n{'='*55}")
        print(f" PHASE {num} β€” "
              f"{'CNN FROZEN (RNN+FC only)' if freeze_cnn else 'FULL MODEL (all layers)'}"
              f" LR={lr}")
        print(f"{'='*55}")
        # Freeze policy: any parameter whose name contains 'cnn' is frozen
        # when freeze_cnn is set; everything else always trains.
        for name, param in model.named_parameters():
            param.requires_grad = not (freeze_cnn and 'cnn' in name)
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f" Trainable params : {trainable:,}")
        opt = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
        sched = optim.lr_scheduler.ReduceLROnPlateau(opt, patience=3, factor=0.5)
        best = float('inf')
        counter = 0
        for epoch in range(1, epochs + 1):
            tr = run_epoch(train_loader, True, opt)
            vl = run_epoch(val_loader, False, None)
            sched.step(vl)
            if vl < best:
                best = vl
                counter = 0
                # Save a full checkpoint (weights + charset + config) on
                # every validation improvement.
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'config': config,
                    'char_to_idx': train_dataset.char_to_idx,
                    'idx_to_char': train_dataset.idx_to_char,
                    'epoch': epoch,
                    'val_loss': vl, # FIXED: renamed from val_cer — this is val loss, not CER%
                }, CHECKPOINT_OUT)
                print(f" Epoch {epoch:02d}/{epochs} "
                      f"Train={tr:.4f} Val={vl:.4f} <- saved")
            else:
                counter += 1
                print(f" Epoch {epoch:02d}/{epochs} "
                      f"Train={tr:.4f} Val={vl:.4f} "
                      f"(patience {counter}/{patience})")
                if counter >= patience:
                    print(f" Early stopping at epoch {epoch}.")
                    break
        return best

    # Phase 1: Freeze CNN
    p1 = run_phase(1, epochs=30, lr=1e-4, freeze_cnn=True, patience=7)
    # Phase 2: Full model, very low LR
    p2 = run_phase(2, epochs=20, lr=1e-6, freeze_cnn=False, patience=5)
    print(f"\n{'='*55}")
    print(f"IAM fine-tuning complete!")
    print(f" Phase 1 best val loss : {p1:.4f}")
    print(f" Phase 2 best val loss : {p2:.4f}")
    print(f" Saved : {CHECKPOINT_OUT}")
    print(f"\nNext step: collect physical certificate scans")
# ─────────────────────────────────────────────
# MAIN
# ─────────────────────────────────────────────
if __name__ == "__main__":
    # CLI entry point: pick which pipeline stages to run.
    arg_parser = argparse.ArgumentParser()
    for flag in ("--prepare", "--train"):
        arg_parser.add_argument(flag, action="store_true")
    cli = arg_parser.parse_args()

    if not (cli.prepare or cli.train):
        # No stage selected — show usage and exit successfully.
        usage = [
            "Usage:",
            " python IAM_train.py --prepare # prepare dataset",
            " python IAM_train.py --train # train model",
            " python IAM_train.py --prepare --train # do both",
        ]
        print("\n".join(usage))
        sys.exit(0)

    if cli.prepare:
        prepare_iam()
    if cli.train:
        train_iam()