| """ |
| finetune.py |
| =========== |
| Fine-tune CRNN+CTC on generated civil registry form crops. |
| |
| Continues from best_model_v3.pth, trains on actual_annotations.json |
| + train_annotations.json, saves to best_model_v4.pth. |
| |
| Usage: |
| python finetune.py |
| |
| Output: |
| checkpoints/best_model_v4.pth |
| """ |
|
|
| import os |
| import sys |
| import json |
| import shutil |
| import torch |
| import torch.nn.functional as F |
| import torch.optim as optim |
| from torch.utils.data import DataLoader, ConcatDataset |
|
|
| sys.path.append('.') |
| from crnn_model import get_crnn_model |
| from dataset import CivilRegistryDataset, collate_fn |
|
|
| |
# Checkpoint to resume from, and where the best fine-tuned weights are written.
CHECKPOINT_IN = "checkpoints/best_model_v3.pth"
CHECKPOINT_OUT = "checkpoints/best_model_v4.pth"

# Annotation files. Training uses ACTUAL_ANN and/or SYNTH_ANN (whichever
# exist); VAL_ANN is mandatory and is used for model selection.
ACTUAL_ANN = "data/actual_annotations.json"
SYNTH_ANN = "data/train_annotations.json"
VAL_ANN = "data/val_annotations.json"

# Optional copy destination for the best checkpoint; it is only used when its
# parent directory exists (presumably a mounted Google Drive in Colab).
DRIVE_BACKUP = "/content/drive/MyDrive/crnn_finetune/CRNN+CTC/checkpoints/best_model_v4.pth"

# Input geometry handed to the datasets, and the loader batch size.
IMG_HEIGHT = 64
IMG_WIDTH = 512
BATCH_SIZE = 32

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training schedule, executed in order. Each entry is
#   (name, max_epochs, learning_rate, freeze_cnn, early_stop_patience)
# The names are printed verbatim; the dash was previously a mis-decoded
# character ("β") and is restored here.
PHASES = [
    ("Phase 1 — CNN frozen, warm up on actual crops", 20, 1e-4, True, 5),
    ("Phase 2 — Full model, main training", 30, 1e-5, False, 6),
    ("Phase 3 — Full model, slow burn", 30, 5e-6, False, 6),
    ("Phase 4 — Full model, final polish", 20, 1e-6, False, 5),
]
|
|
| |
def fix_paths(json_path):
    """Normalize backslashes in the ``image_path`` fields of an annotation file.

    Loads the JSON document at *json_path*, replaces every backslash in each
    entry's ``image_path`` with a forward slash, and rewrites the file in
    place only if at least one path was changed.

    Args:
        json_path: Path to a JSON file containing an iterable of annotation
            entries. Entries that are not dicts, or whose ``image_path`` is
            missing or not a string, are left untouched.
    """
    # JSON is UTF-8 by specification. Pin the encoding so labels with
    # non-ASCII characters are not mis-decoded under a different platform
    # default and then written back corrupted.
    with open(json_path, encoding="utf-8") as f:
        annotations = json.load(f)

    changed = False
    for entry in annotations:
        if not isinstance(entry, dict):
            continue
        path = entry.get('image_path')
        if isinstance(path, str) and '\\' in path:
            entry['image_path'] = path.replace('\\', '/')
            changed = True

    # Only touch the file when something was actually rewritten.
    if changed:
        with open(json_path, 'w', encoding="utf-8") as f:
            json.dump(annotations, f)
        print(f" Fixed backslash paths in {json_path}")
|
|
| |
def _load_compatible_weights(model, saved_state):
    """Load checkpoint tensors whose shapes fit *model*; return the missing keys.

    ``load_state_dict(strict=False)`` tolerates missing/unexpected keys but
    still raises on a shape mismatch, so tensors whose shape differs from the
    model's (e.g. an output layer sized for another vocabulary) are dropped
    first and therefore keep their fresh initialization.
    """
    current = model.state_dict()
    compatible = {
        key: tensor
        for key, tensor in saved_state.items()
        if key not in current or tensor.shape == current[key].shape
    }
    missing, _ = model.load_state_dict(compatible, strict=False)
    return missing


def _run_epoch(model, criterion, loader, training, optimizer=None):
    """Run one pass over *loader* and return the mean loss of the usable batches.

    Batches whose loss is NaN or infinite are skipped. If no batch yields a
    finite loss (or the loader is empty) ``inf`` is returned rather than 0.0,
    so such an epoch can never be mistaken for the best one.
    """
    model.train(training)
    total, counted = 0.0, 0
    grad_ctx = torch.enable_grad() if training else torch.no_grad()
    with grad_ctx:
        for images, targets, target_lengths, _ in loader:
            images = images.to(DEVICE)
            batch_size = images.size(0)
            if training:
                optimizer.zero_grad()
            outputs = F.log_softmax(model(images), dim=2)
            # Dimension 0 of the model output is used as the CTC input length
            # for every sample in the batch.
            seq_len = outputs.size(0)
            input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long)
            loss = criterion(outputs, targets, input_lengths, target_lengths)
            if torch.isnan(loss) or torch.isinf(loss):
                continue
            if training:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
                optimizer.step()
            total += loss.item()
            counted += 1
    return total / counted if counted else float('inf')


def main():
    """Fine-tune the CRNN checkpoint through the schedule defined in PHASES.

    Steps:
      1. Verify that CHECKPOINT_IN and VAL_ANN exist (exit with status 1 if not).
      2. Normalize path separators in the annotation files that exist.
      3. Build the training set from ACTUAL_ANN and/or SYNTH_ANN and the
         validation set from VAL_ANN.
      4. Rebuild the model from the checkpoint's config and load every
         shape-compatible weight.
      5. Run each phase with its own Adam optimizer, ReduceLROnPlateau
         scheduler and early stopping; whenever validation loss beats the
         best seen in this run, write CHECKPOINT_OUT.
      6. After each phase that wrote a checkpoint, copy it to DRIVE_BACKUP if
         that directory exists.
    """
    print("=" * 60)
    print(" Fine-tuning CRNN+CTC on civil registry form crops")
    print("=" * 60)
    print(f" Device : {DEVICE}")
    print(f" Checkpoint : {CHECKPOINT_IN}")

    # Hard requirements: nothing can be done without these two files.
    for required in (CHECKPOINT_IN, VAL_ANN):
        if not os.path.exists(required):
            print(f"ERROR: {required} not found.")
            sys.exit(1)

    for ann_file in (ACTUAL_ANN, SYNTH_ANN, VAL_ANN):
        if os.path.exists(ann_file):
            fix_paths(ann_file)

    # Training data: whichever of the two annotation sets is available.
    datasets_to_merge = []

    if os.path.exists(ACTUAL_ANN):
        actual_dataset = CivilRegistryDataset(
            data_dir=".", annotations_file=ACTUAL_ANN,
            img_height=IMG_HEIGHT, img_width=IMG_WIDTH, augment=True
        )
        datasets_to_merge.append(actual_dataset)
        print(f" Actual crops: {len(actual_dataset)} (real scanned forms)")
    else:
        print(f" [!] {ACTUAL_ANN} not found")

    if os.path.exists(SYNTH_ANN):
        synth_dataset = CivilRegistryDataset(
            data_dir=".", annotations_file=SYNTH_ANN,
            img_height=IMG_HEIGHT, img_width=IMG_WIDTH, augment=True
        )
        datasets_to_merge.append(synth_dataset)
        print(f" Synth crops : {len(synth_dataset)} (fully synthetic)")

    if not datasets_to_merge:
        print("ERROR: No training data found.")
        sys.exit(1)

    val_dataset = CivilRegistryDataset(
        data_dir=".", annotations_file=VAL_ANN,
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH, augment=False
    )

    if len(datasets_to_merge) > 1:
        train_dataset = ConcatDataset(datasets_to_merge)
    else:
        train_dataset = datasets_to_merge[0]
    print(f" Total train : {len(train_dataset)}")
    print(f" Val : {len(val_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=0, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                            shuffle=False, num_workers=0, collate_fn=collate_fn)

    print(f"\n Loading {CHECKPOINT_IN}...")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints produced by a trusted source.
    ckpt = torch.load(CHECKPOINT_IN, map_location=DEVICE, weights_only=False)
    config = ckpt.get('config', {})

    # The vocabulary (and hence the output size) comes from the first
    # training dataset, not from the checkpoint.
    ref_dataset = datasets_to_merge[0]
    model = get_crnn_model(
        model_type=config.get('model_type', 'standard'),
        img_height=config.get('img_height', 64),
        num_chars=ref_dataset.num_chars,
        hidden_size=config.get('hidden_size', 128),
        num_lstm_layers=config.get('num_lstm_layers', 1),
    ).to(DEVICE)

    missing = _load_compatible_weights(model, ckpt['model_state_dict'])
    if missing:
        print(f" Note: {len(missing)} layers re-initialized (expected if vocab size changed)")
    print(f" Loaded epoch {ckpt.get('epoch','?')} "
          f"val_loss={ckpt.get('val_loss', ckpt.get('val_cer', 0)):.4f}")

    criterion = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
    os.makedirs(os.path.dirname(CHECKPOINT_OUT) or ".", exist_ok=True)

    best_overall = float('inf')

    for phase_name, epochs, lr, freeze_cnn, patience in PHASES:
        print(f"\n{'='*60}")
        print(f" {phase_name} LR={lr}")
        print(f"{'='*60}")

        # Freeze (or unfreeze) every parameter whose name contains 'cnn'.
        for name, param in model.named_parameters():
            param.requires_grad = not (freeze_cnn and 'cnn' in name)

        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f" Trainable params : {trainable:,}")

        # Fresh optimizer/scheduler per phase, over the trainable subset only.
        opt = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
        sched = optim.lr_scheduler.ReduceLROnPlateau(opt, patience=2, factor=0.5)
        best = float('inf')
        wait = 0
        saved_this_phase = False

        for epoch in range(1, epochs + 1):
            tr = _run_epoch(model, criterion, train_loader, True, opt)
            vl = _run_epoch(model, criterion, val_loader, False)
            sched.step(vl)

            if vl < best:
                best = vl
                wait = 0
                # A phase-level improvement is only written to disk when it
                # also beats the best loss of the whole run.
                saved = vl < best_overall
                if saved:
                    best_overall = vl
                    torch.save({
                        **ckpt,
                        'model_state_dict': model.state_dict(),
                        'config': config,
                        'char_to_idx': ref_dataset.char_to_idx,
                        'idx_to_char': ref_dataset.idx_to_char,
                        'epoch': epoch,
                        'val_loss': vl,
                    }, CHECKPOINT_OUT)
                    saved_this_phase = True
                saved_tag = " <- saved" if saved else ""
                print(f" Epoch {epoch:02d}/{epochs} Train={tr:.4f} Val={vl:.4f}{saved_tag}")
            else:
                wait += 1
                print(f" Epoch {epoch:02d}/{epochs} Train={tr:.4f} Val={vl:.4f} (patience {wait}/{patience})")
                if wait >= patience:
                    print(" Early stopping.")
                    break

        # Back up only a checkpoint written in this phase, so a stale file
        # left by a previous run is never copied over the Drive copy.
        if (saved_this_phase
                and os.path.exists(CHECKPOINT_OUT)
                and os.path.exists(os.path.dirname(DRIVE_BACKUP))):
            shutil.copy(CHECKPOINT_OUT, DRIVE_BACKUP)
            print(f"\n Backed up to Drive: {DRIVE_BACKUP}")

    print(f"\n{'='*60}")
    print(" Fine-tuning complete!")
    print(f" Best val loss : {best_overall:.4f}")
    print(f" Saved : {CHECKPOINT_OUT}")
    print(f"{'='*60}")
|
|
|
|
# Script entry point: start fine-tuning only when run directly, not on import.
if __name__ == "__main__":
    main()