# Copyright (C) 2021-2025, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import numpy as np
import torch
from tqdm import tqdm

from doctr import datasets
from doctr import transforms as T
from doctr.models import ocr_predictor
from doctr.utils.geometry import extract_crops, extract_rcrops
from doctr.utils.metrics import LocalizationConfusion, OCRMetric, TextMatch


def _pct(val):
    """Format a ratio as a percentage string, or "N/A" when the metric is undefined (None)."""
    return "N/A" if val is None else f"{val:.2%}"


def main(args):
    """Run end-to-end OCR evaluation (detection + recognition + combined metrics).

    Builds an ``ocr_predictor`` from the requested detection/recognition
    architectures, streams one or two datasets through it, and prints
    localization, recognition and end-to-end summary metrics.

    Args:
        args: parsed CLI namespace produced by ``parse_args()``.
    """
    # Straight-bbox evaluation is the only meaningful mode without rotated post-processing.
    if not args.rotation:
        args.eval_straight = True

    input_shape = (args.size, args.size)

    # We define a transformation function which does transform the annotation
    # to the required format for the Resize transformation
    def _transform(img, target):
        boxes = target["boxes"]

        transformed_img, transformed_boxes = T.Resize(
            input_shape, preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad
        )(img, boxes)
        return transformed_img, {"boxes": transformed_boxes, "labels": target["labels"]}

    predictor = ocr_predictor(
        args.detection,
        args.recognition,
        pretrained=True,
        reco_bs=args.batch_size,
        preserve_aspect_ratio=False,  # we handle the transformation directly in the dataset so this is set to False
        symmetric_pad=False,  # we handle the transformation directly in the dataset so this is set to False
        assume_straight_pages=not args.rotation,
    )

    if torch.cuda.is_available():
        predictor = predictor.cuda()

    if args.img_folder and args.label_file:
        # Local dataset: single test set built from user-provided images/labels.
        testset = datasets.OCRDataset(
            img_folder=args.img_folder,
            label_file=args.label_file,
            sample_transforms=_transform,
        )
        sets = [testset]
    else:
        # Built-in dataset: evaluate on both the train and validation splits.
        train_set = datasets.__dict__[args.dataset](
            train=True, download=True, use_polygons=not args.eval_straight, sample_transforms=_transform
        )
        val_set = datasets.__dict__[args.dataset](
            train=False, download=True, use_polygons=not args.eval_straight, sample_transforms=_transform
        )
        sets = [train_set, val_set]

    reco_metric = TextMatch()
    det_metric = LocalizationConfusion(iou_thresh=args.iou, use_polygons=not args.eval_straight)
    e2e_metric = OCRMetric(iou_thresh=args.iou, use_polygons=not args.eval_straight)

    sample_idx = 0
    # Axis-aligned crops for straight evaluation, rotated crops otherwise.
    extraction_fn = extract_crops if args.eval_straight else extract_rcrops

    for dataset in sets:
        for page, target in tqdm(dataset):
            if isinstance(page, torch.Tensor):
                # CHW tensor -> HWC numpy array, as expected by the predictor/crop helpers.
                page = np.transpose(page.numpy(), (1, 2, 0))
            # GT
            gt_boxes = target["boxes"]
            gt_labels = target["labels"]

            if args.img_folder and args.label_file:
                # OCRDataset boxes come as (cx, cy, w, h); convert to clipped (xmin, ymin, xmax, ymax).
                x, y, w, h = gt_boxes[:, 0], gt_boxes[:, 1], gt_boxes[:, 2], gt_boxes[:, 3]
                xmin, ymin = np.clip(x - w / 2, 0, 1), np.clip(y - h / 2, 0, 1)
                xmax, ymax = np.clip(x + w / 2, 0, 1), np.clip(y + h / 2, 0, 1)
                gt_boxes = np.stack([xmin, ymin, xmax, ymax], axis=-1)

            # Forward
            with torch.no_grad():
                out = predictor(page[None, ...])
            # Recognition is also scored on GT crops, isolating it from detection errors.
            crops = extraction_fn(page, gt_boxes)
            reco_out = predictor.reco_predictor(crops)

            if len(reco_out):
                reco_words, _ = zip(*reco_out)
            else:
                reco_words = []

            # Unpack preds
            pred_boxes = []
            pred_labels = []
            # BUGFIX: use a dedicated loop variable instead of shadowing the outer
            # `page` (the input image array) with a doctr Page result object.
            for out_page in out.pages:
                height, width = out_page.dimensions
                for block in out_page.blocks:
                    for line in block.lines:
                        for word in line.words:
                            if not args.rotation:
                                (a, b), (c, d) = word.geometry
                            else:
                                (
                                    [x1, y1],
                                    [x2, y2],
                                    [x3, y3],
                                    [x4, y4],
                                ) = word.geometry
                            # Match the dtype of the GT boxes: absolute pixel ints vs relative floats.
                            if np.issubdtype(gt_boxes.dtype, np.integer):
                                if not args.rotation:
                                    pred_boxes.append([
                                        int(a * width),
                                        int(b * height),
                                        int(c * width),
                                        int(d * height),
                                    ])
                                else:
                                    if args.eval_straight:
                                        # Collapse the polygon to its axis-aligned bounding box.
                                        pred_boxes.append([
                                            int(width * min(x1, x2, x3, x4)),
                                            int(height * min(y1, y2, y3, y4)),
                                            int(width * max(x1, x2, x3, x4)),
                                            int(height * max(y1, y2, y3, y4)),
                                        ])
                                    else:
                                        pred_boxes.append([
                                            [int(x1 * width), int(y1 * height)],
                                            [int(x2 * width), int(y2 * height)],
                                            [int(x3 * width), int(y3 * height)],
                                            [int(x4 * width), int(y4 * height)],
                                        ])
                            else:
                                if not args.rotation:
                                    pred_boxes.append([a, b, c, d])
                                else:
                                    if args.eval_straight:
                                        # Collapse the polygon to its axis-aligned bounding box.
                                        pred_boxes.append([
                                            min(x1, x2, x3, x4),
                                            min(y1, y2, y3, y4),
                                            max(x1, x2, x3, x4),
                                            max(y1, y2, y3, y4),
                                        ])
                                    else:
                                        pred_boxes.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
                            pred_labels.append(word.value)

            # Update the metric
            det_metric.update(gt_boxes, np.asarray(pred_boxes))
            reco_metric.update(gt_labels, reco_words)
            e2e_metric.update(gt_boxes, np.asarray(pred_boxes), gt_labels, pred_labels)

            # Loop break
            sample_idx += 1
            if isinstance(args.samples, int) and args.samples == sample_idx:
                break
        if isinstance(args.samples, int) and args.samples == sample_idx:
            break

    # Unpack aggregated metrics
    print(
        f"Model Evaluation (model= {args.detection} + {args.recognition}, "
        f"dataset={'OCRDataset' if args.img_folder else args.dataset})"
    )
    recall, precision, mean_iou = det_metric.summary()
    print(f"Text Detection - Recall: {_pct(recall)}, Precision: {_pct(precision)}, Mean IoU: {_pct(mean_iou)}")
    acc = reco_metric.summary()
    print(f"Text Recognition - Accuracy: {_pct(acc['raw'])} (unicase: {_pct(acc['unicase'])})")
    recall, precision, mean_iou = e2e_metric.summary()
    print(
        f"OCR - Recall: {_pct(recall['raw'])} (unicase: {_pct(recall['unicase'])}), "
        f"Precision: {_pct(precision['raw'])} (unicase: {_pct(precision['unicase'])}), Mean IoU: {_pct(mean_iou)}"
    )


def parse_args():
    """Build and parse the CLI arguments for the evaluation script."""
    import argparse

    parser = argparse.ArgumentParser(
        description="DocTR end-to-end evaluation", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("detection", type=str, help="Text detection model to use for analysis")
    parser.add_argument("recognition", type=str, help="Text recognition model to use for analysis")
    parser.add_argument("--iou", type=float, default=0.5, help="IoU threshold to match a pair of boxes")
    parser.add_argument("--dataset", type=str, default="FUNSD", help="choose a dataset: FUNSD, CORD")
    parser.add_argument("--img_folder", type=str, default=None, help="Only for local sets, path to images")
    parser.add_argument("--label_file", type=str, default=None, help="Only for local sets, path to labels")
    parser.add_argument("--rotation", dest="rotation", action="store_true", help="run rotated OCR + postprocessing")
    parser.add_argument("-b", "--batch_size", type=int, default=32, help="batch size for recognition")
    parser.add_argument("--size", type=int, default=1024, help="model input size, H = W")
    parser.add_argument("--keep_ratio", action="store_true", help="keep the aspect ratio of the input image")
    parser.add_argument("--symmetric_pad", action="store_true", help="pad the image symmetrically")
    parser.add_argument("--samples", type=int, default=None, help="evaluate only on the N first samples")
    parser.add_argument(
        "--eval-straight",
        action="store_true",
        help="evaluate on straight pages with straight bbox (to use the quick and light metric)",
    )
    args = parser.parse_args()

    return args


if __name__ == "__main__":
    args = parse_args()
    main(args)