| """ |
| IAM_train.py |
| ============ |
| Fine-tune the CRNN model using the IAM Handwriting Word Database. |
| Builds on top of EMNIST-trained model (best_model_emnist.pth). |
| |
| FIXES vs old version: |
| - IMG_WIDTH 400 -> 512 (must match pipeline) |
- Added log_softmax before CTCLoss (was missing — caused catastrophic forgetting)
- Phase 1: CNN FROZEN — only RNN+FC trained
| - Phase 2: Full model at very low LR |
| - Loads from best_model_emnist.pth, falls back to best_model.pth |
| - Uses get_crnn_model() with correct architecture from checkpoint config |
| |
| DATASET: |
| Download from: https://www.kaggle.com/datasets/nibinv23/iam-handwriting-word-database |
| Expected structure: |
| data/IAM/iam_words/ |
| words/ <- word image folders (a01, a02, ...) |
| words.txt <- annotation file |
| |
| USAGE: |
| python IAM_train.py --prepare # convert IAM -> annotation JSON |
| python IAM_train.py --train # fine-tune model |
| python IAM_train.py --prepare --train # do both |
| """ |
|
|
| import os |
| import sys |
| import json |
| import argparse |
| import random |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.optim as optim |
| from torch.utils.data import DataLoader, ConcatDataset |
|
|
| sys.path.append('.') |
| from crnn_model import get_crnn_model |
| from dataset import CivilRegistryDataset, collate_fn |
|
|
| |
| |
| |
# ---- IAM dataset locations (downloaded Kaggle layout, see module docstring) ----
IAM_ROOT = "data/IAM/iam_words"
IAM_WORDS_TXT = f"{IAM_ROOT}/words.txt"   # per-word annotation file
IAM_WORDS_DIR = f"{IAM_ROOT}/words"       # word image folders (a01, a02, ...)

# ---- Files produced by --prepare and consumed by --train ----
TRAIN_ANN = "data/iam_train_annotations.json"
IAM_VAL_ANN = "data/iam_val_annotations.json"
SYNTH_VAL_ANN = "data/val_annotations.json"   # synthetic-pipeline validation set
TRAIN_IMG_DIR = "data/train/iam"
VAL_IMG_DIR = "data/val/iam"

# ---- Image geometry / training hyperparameters ----
IMG_HEIGHT = 64        # must match the rest of the pipeline
IMG_WIDTH = 512        # must match the pipeline (was 400 in the old version)
BATCH_SIZE = 32
VAL_SPLIT = 0.1        # fraction of IAM entries held out for validation
MAX_SAMPLES = 50000    # cap on total IAM word samples used

# ---- Checkpoints: load EMNIST-trained weights if present, else base model ----
CHECKPOINT_IN = "checkpoints/best_model_emnist.pth"
CHECKPOINT_IN2 = "checkpoints/best_model.pth"
CHECKPOINT_OUT = "checkpoints/best_model_iam.pth"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| |
| |
| |
def prepare_iam():
    """Convert the raw IAM word database into resized JPEGs plus annotation JSON.

    Reads IAM_WORDS_TXT, keeps only entries whose segmentation result is
    "ok" and whose transcription is 1-32 characters, caps the total at
    MAX_SAMPLES, shuffles and splits into train/val, resizes each image to
    IMG_WIDTH x IMG_HEIGHT, and writes TRAIN_ANN / IAM_VAL_ANN for
    train_iam() to consume.  Exits the process if the dataset is missing.
    """
    from PIL import Image  # local import: only needed for --prepare

    print("\n" + "=" * 50)
    print("STEP 1 — Preparing IAM dataset")
    print("=" * 50)

    if not os.path.exists(IAM_WORDS_TXT):
        print(f"ERROR: {IAM_WORDS_TXT} not found!")
        print("Download from: https://www.kaggle.com/datasets/nibinv23/iam-handwriting-word-database")
        print("Expected structure:")
        print(" data/IAM/iam_words/words.txt")
        print(" data/IAM/iam_words/words/")
        sys.exit(1)

    os.makedirs(TRAIN_IMG_DIR, exist_ok=True)
    os.makedirs(VAL_IMG_DIR, exist_ok=True)

    entries = []
    print(f" Reading {IAM_WORDS_TXT} ...")
    with open(IAM_WORDS_TXT, "r") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):  # skip header comments
                continue
            parts = line.split(" ")
            if len(parts) < 9:  # malformed line: fewer than 8 metadata fields + text
                continue
            word_id = parts[0]
            seg_result = parts[1]
            # FIX: the transcription is everything after the 8 metadata fields.
            # The old `parts[-1]` silently dropped all but the last token when
            # a transcription contained spaces.
            text = " ".join(parts[8:])
            if seg_result != "ok":  # keep only correctly-segmented words
                continue
            if len(text) < 1 or len(text) > 32:
                continue
            # word_id "a01-000u-00-00" -> words/a01/a01-000u/a01-000u-00-00.png
            parts_id = word_id.split("-")
            img_path = os.path.join(
                IAM_WORDS_DIR,
                parts_id[0],
                f"{parts_id[0]}-{parts_id[1]}",
                f"{word_id}.png"
            )
            if not os.path.exists(img_path):
                continue
            entries.append((img_path, text))

    print(f" Found {len(entries)} valid word entries")

    if MAX_SAMPLES and len(entries) > MAX_SAMPLES:
        random.shuffle(entries)
        entries = entries[:MAX_SAMPLES]
        print(f" Limiting to {MAX_SAMPLES} samples")

    random.shuffle(entries)
    split_idx = int(len(entries) * (1 - VAL_SPLIT))
    train_entries = entries[:split_idx]
    val_entries = entries[split_idx:]
    print(f" Train: {len(train_entries)} | Val: {len(val_entries)}")
    print(" Copying and resizing images...")

    def process_entries(entry_list, out_dir, prefix):
        # Resize each source image and emit {"image_path", "text"} records.
        annotations = []
        for i, (src_path, text) in enumerate(entry_list):
            # FIX: progress report moved before the try block so a corrupt
            # image no longer skips it via the `continue` in the handler.
            if i % 5000 == 0:
                print(f" {i}/{len(entry_list)} processed...")
            try:
                img = Image.open(src_path).convert("RGB")
                # NOTE: plain resize ignores aspect ratio — matches the
                # rest of the pipeline's fixed-size input.
                img = img.resize((IMG_WIDTH, IMG_HEIGHT))
                fname = f"iam_{prefix}_{i:06d}.jpg"
                out_path = os.path.join(out_dir, fname)
                img.save(out_path, quality=90)
                annotations.append({"image_path": f"iam/{fname}", "text": text})
            except Exception:
                continue  # best-effort: skip unreadable/corrupt images
        return annotations

    train_ann = process_entries(train_entries, TRAIN_IMG_DIR, "train")
    val_ann = process_entries(val_entries, VAL_IMG_DIR, "val")

    with open(TRAIN_ANN, "w") as f:
        json.dump(train_ann, f, indent=2)
    with open(IAM_VAL_ANN, "w") as f:
        json.dump(val_ann, f, indent=2)

    print(f"\n Train annotations -> {TRAIN_ANN} ({len(train_ann)} entries)")
    print(f" Val annotations -> {IAM_VAL_ANN} ({len(val_ann)} entries)")
    print("\n Done! Now run: python IAM_train.py --train")
|
|
|
|
| |
| |
| |
def train_iam():
    """Fine-tune the CRNN on IAM words mixed with synthetic data, in two phases.

    Phase 1: CNN frozen, only RNN+FC trained (LR 1e-4, up to 30 epochs).
    Phase 2: full model unfrozen at a very low LR (1e-6, up to 20 epochs).

    Starts from CHECKPOINT_IN (falling back to CHECKPOINT_IN2) and saves the
    best-validation-loss model to CHECKPOINT_OUT.  Exits the process if the
    annotation files or input checkpoints are missing.
    """
    print("\n" + "=" * 55)
    print("STEP 2 β Fine-tuning CRNN with IAM dataset")
    print("=" * 55)
    print(f" Device : {DEVICE}")

    # Fail fast if --prepare has not been run / synthetic data is absent.
    for ann_file in [TRAIN_ANN, SYNTH_VAL_ANN]:
        if not os.path.exists(ann_file):
            print(f"ERROR: {ann_file} not found! Run --prepare first.")
            sys.exit(1)

    # IAM words produced by prepare_iam(), with augmentation.
    train_dataset = CivilRegistryDataset(
        data_dir="data/train", annotations_file=TRAIN_ANN,
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH, augment=True
    )
    # Synthetic data is mixed in, presumably to limit forgetting of the
    # original domain during fine-tuning (matches the module docstring intent).
    synth_dataset = CivilRegistryDataset(
        data_dir="data/train", annotations_file="data/train_annotations.json",
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH, augment=True
    )
    mixed_train = ConcatDataset([train_dataset, synth_dataset])
    # NOTE(review): validation runs on the SYNTHETIC val set, not the IAM
    # val split written by --prepare — confirm this is intentional.
    val_dataset = CivilRegistryDataset(
        data_dir="data/val", annotations_file=SYNTH_VAL_ANN,
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH, augment=False
    )
    print(f" IAM train : {len(train_dataset)}")
    print(f" Synthetic train: {len(synth_dataset)}")
    print(f" Mixed train : {len(mixed_train)}")
    print(f" Val : {len(val_dataset)}")

    train_loader = DataLoader(mixed_train, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=0, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                            shuffle=False, num_workers=0, collate_fn=collate_fn)

    # Prefer the EMNIST-trained checkpoint; fall back to the base model.
    ckpt_path = CHECKPOINT_IN if os.path.exists(CHECKPOINT_IN) else CHECKPOINT_IN2
    if not os.path.exists(ckpt_path):
        print(f"ERROR: No checkpoint found at {CHECKPOINT_IN} or {CHECKPOINT_IN2}")
        print("Run: python train.py then python train_with_emnist.py")
        sys.exit(1)

    print(f" Loading: {ckpt_path}")
    # weights_only=False: checkpoint stores a config dict alongside weights.
    ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)
    config = ckpt.get('config', {})

    # Rebuild the architecture from the checkpoint's own config so the
    # state_dict keys line up; num_chars comes from the current charset.
    model = get_crnn_model(
        model_type = config.get('model_type', 'standard'),
        img_height = config.get('img_height', 64),
        num_chars = train_dataset.num_chars,
        hidden_size = config.get('hidden_size', 128),
        num_lstm_layers = config.get('num_lstm_layers', 1),
    ).to(DEVICE)

    # strict=False: layers whose shapes changed (e.g. FC for a new charset)
    # stay randomly initialized instead of raising.
    missing, _ = model.load_state_dict(ckpt['model_state_dict'], strict=False)
    if missing:
        print(f" Note: {len(missing)} layers re-initialized")
    print(f" Loaded epoch {ckpt.get('epoch', 'N/A')} "
          f"val_loss={ckpt.get('val_loss', ckpt.get('val_cer', 0)):.4f}")

    # blank=0 must match the dataset's index mapping; zero_infinity guards
    # against targets longer than the output sequence.
    criterion = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
    os.makedirs("checkpoints", exist_ok=True)

    def run_epoch(loader, training, optimizer=None):
        # One pass over `loader`; returns mean CTC loss over valid batches.
        model.train() if training else model.eval()
        total, n = 0, 0
        ctx = torch.enable_grad() if training else torch.no_grad()
        with ctx:
            for images, targets, target_lengths, _ in loader:
                images = images.to(DEVICE)
                batch_size = images.size(0)
                if training:
                    optimizer.zero_grad()

                # CTCLoss expects log-probabilities; log_softmax over the
                # class dim.  Model output assumed (seq_len, batch, classes)
                # — the CTC layout; confirm against crnn_model if changed.
                outputs = F.log_softmax(model(images), dim=2)
                seq_len = outputs.size(0)
                # Every sample claims the full output length.
                input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long)
                loss = criterion(outputs, targets, input_lengths, target_lengths)
                # Skip degenerate batches instead of poisoning gradients/stats.
                if not torch.isnan(loss) and not torch.isinf(loss):
                    if training:
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
                        optimizer.step()
                    total += loss.item()
                    n += 1
        return total / max(n, 1)  # max(n,1): avoid div-by-zero if all batches skipped

    def run_phase(num, epochs, lr, freeze_cnn, patience):
        # One training phase: optionally freeze the CNN, train with early
        # stopping on val loss, checkpoint on improvement.  Returns best val loss.
        print(f"\n{'='*55}")
        print(f" PHASE {num} β "
              f"{'CNN FROZEN (RNN+FC only)' if freeze_cnn else 'FULL MODEL (all layers)'}"
              f" LR={lr}")
        print(f"{'='*55}")

        # Assumes CNN parameter names contain 'cnn' — verify against crnn_model.
        for name, param in model.named_parameters():
            param.requires_grad = not (freeze_cnn and 'cnn' in name)

        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f" Trainable params : {trainable:,}")

        # Fresh optimizer per phase, over trainable params only.
        opt = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
        sched = optim.lr_scheduler.ReduceLROnPlateau(opt, patience=3, factor=0.5)
        best = float('inf')
        counter = 0  # epochs since last improvement (early-stopping counter)

        for epoch in range(1, epochs + 1):
            tr = run_epoch(train_loader, True, opt)
            vl = run_epoch(val_loader, False, None)
            sched.step(vl)  # halve LR after 3 epochs without improvement

            if vl < best:
                best = vl
                counter = 0
                # Save char mappings alongside weights so inference can
                # decode without the dataset object.
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'config': config,
                    'char_to_idx': train_dataset.char_to_idx,
                    'idx_to_char': train_dataset.idx_to_char,
                    'epoch': epoch,
                    'val_loss': vl,
                }, CHECKPOINT_OUT)
                print(f" Epoch {epoch:02d}/{epochs} "
                      f"Train={tr:.4f} Val={vl:.4f} <- saved")
            else:
                counter += 1
                print(f" Epoch {epoch:02d}/{epochs} "
                      f"Train={tr:.4f} Val={vl:.4f} "
                      f"(patience {counter}/{patience})")
                if counter >= patience:
                    print(f" Early stopping at epoch {epoch}.")
                    break
        return best

    # Phase 1: adapt RNN+FC while CNN features stay fixed.
    p1 = run_phase(1, epochs=30, lr=1e-4, freeze_cnn=True, patience=7)
    # Phase 2: gentle full-model polish at a very low LR.
    p2 = run_phase(2, epochs=20, lr=1e-6, freeze_cnn=False, patience=5)

    print(f"\n{'='*55}")
    print(f"IAM fine-tuning complete!")
    print(f" Phase 1 best val loss : {p1:.4f}")
    print(f" Phase 2 best val loss : {p2:.4f}")
    print(f" Saved : {CHECKPOINT_OUT}")
    print(f"\nNext step: collect physical certificate scans")
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    # CLI entry point: run preparation and/or fine-tuning based on flags.
    cli = argparse.ArgumentParser()
    for flag in ("--prepare", "--train"):
        cli.add_argument(flag, action="store_true")
    opts = cli.parse_args()

    # No flags given: show usage and exit cleanly.
    if not (opts.prepare or opts.train):
        for usage_line in (
            "Usage:",
            " python IAM_train.py --prepare # prepare dataset",
            " python IAM_train.py --train # train model",
            " python IAM_train.py --prepare --train # do both",
        ):
            print(usage_line)
        sys.exit(0)

    if opts.prepare:
        prepare_iam()
    if opts.train:
        train_iam()