MikeTrizna's picture
Upload folder using huggingface_hub
f3270e6 verified
# Copyright (C) 2021-2025, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import numpy as np
import torch
from tqdm import tqdm
from doctr import datasets
from doctr import transforms as T
from doctr.models import ocr_predictor
from doctr.utils.geometry import extract_crops, extract_rcrops
from doctr.utils.metrics import LocalizationConfusion, OCRMetric, TextMatch
def _pct(val):
return "N/A" if val is None else f"{val:.2%}"
def main(args):
if not args.rotation:
args.eval_straight = True
input_shape = (args.size, args.size)
# We define a transformation function which does transform the annotation
# to the required format for the Resize transformation
def _transform(img, target):
boxes = target["boxes"]
transformed_img, transformed_boxes = T.Resize(
input_shape, preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad
)(img, boxes)
return transformed_img, {"boxes": transformed_boxes, "labels": target["labels"]}
predictor = ocr_predictor(
args.detection,
args.recognition,
pretrained=True,
reco_bs=args.batch_size,
preserve_aspect_ratio=False, # we handle the transformation directly in the dataset so this is set to False
symmetric_pad=False, # we handle the transformation directly in the dataset so this is set to False
assume_straight_pages=not args.rotation,
)
if torch.cuda.is_available():
predictor = predictor.cuda()
if args.img_folder and args.label_file:
testset = datasets.OCRDataset(
img_folder=args.img_folder,
label_file=args.label_file,
sample_transforms=_transform,
)
sets = [testset]
else:
train_set = datasets.__dict__[args.dataset](
train=True,
download=True,
use_polygons=not args.eval_straight,
sample_transforms=_transform,
)
val_set = datasets.__dict__[args.dataset](
train=False,
download=True,
use_polygons=not args.eval_straight,
sample_transforms=_transform,
)
sets = [train_set, val_set]
reco_metric = TextMatch()
det_metric = LocalizationConfusion(iou_thresh=args.iou, use_polygons=not args.eval_straight)
e2e_metric = OCRMetric(iou_thresh=args.iou, use_polygons=not args.eval_straight)
sample_idx = 0
extraction_fn = extract_crops if args.eval_straight else extract_rcrops
for dataset in sets:
for page, target in tqdm(dataset):
if isinstance(page, torch.Tensor):
page = np.transpose(page.numpy(), (1, 2, 0))
# GT
gt_boxes = target["boxes"]
gt_labels = target["labels"]
if args.img_folder and args.label_file:
x, y, w, h = gt_boxes[:, 0], gt_boxes[:, 1], gt_boxes[:, 2], gt_boxes[:, 3]
xmin, ymin = np.clip(x - w / 2, 0, 1), np.clip(y - h / 2, 0, 1)
xmax, ymax = np.clip(x + w / 2, 0, 1), np.clip(y + h / 2, 0, 1)
gt_boxes = np.stack([xmin, ymin, xmax, ymax], axis=-1)
# Forward
with torch.no_grad():
out = predictor(page[None, ...])
crops = extraction_fn(page, gt_boxes)
reco_out = predictor.reco_predictor(crops)
if len(reco_out):
reco_words, _ = zip(*reco_out)
else:
reco_words = []
# Unpack preds
pred_boxes = []
pred_labels = []
for page in out.pages:
height, width = page.dimensions
for block in page.blocks:
for line in block.lines:
for word in line.words:
if not args.rotation:
(a, b), (c, d) = word.geometry
else:
(
[x1, y1],
[x2, y2],
[x3, y3],
[x4, y4],
) = word.geometry
if np.issubdtype(gt_boxes.dtype, np.integer):
if not args.rotation:
pred_boxes.append([
int(a * width),
int(b * height),
int(c * width),
int(d * height),
])
else:
if args.eval_straight:
pred_boxes.append([
int(width * min(x1, x2, x3, x4)),
int(height * min(y1, y2, y3, y4)),
int(width * max(x1, x2, x3, x4)),
int(height * max(y1, y2, y3, y4)),
])
else:
pred_boxes.append([
[int(x1 * width), int(y1 * height)],
[int(x2 * width), int(y2 * height)],
[int(x3 * width), int(y3 * height)],
[int(x4 * width), int(y4 * height)],
])
else:
if not args.rotation:
pred_boxes.append([a, b, c, d])
else:
if args.eval_straight:
pred_boxes.append([
min(x1, x2, x3, x4),
min(y1, y2, y3, y4),
max(x1, x2, x3, x4),
max(y1, y2, y3, y4),
])
else:
pred_boxes.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
pred_labels.append(word.value)
# Update the metric
det_metric.update(gt_boxes, np.asarray(pred_boxes))
reco_metric.update(gt_labels, reco_words)
e2e_metric.update(gt_boxes, np.asarray(pred_boxes), gt_labels, pred_labels)
# Loop break
sample_idx += 1
if isinstance(args.samples, int) and args.samples == sample_idx:
break
if isinstance(args.samples, int) and args.samples == sample_idx:
break
# Unpack aggregated metrics
print(
f"Model Evaluation (model= {args.detection} + {args.recognition}, "
f"dataset={'OCRDataset' if args.img_folder else args.dataset})"
)
recall, precision, mean_iou = det_metric.summary()
print(f"Text Detection - Recall: {_pct(recall)}, Precision: {_pct(precision)}, Mean IoU: {_pct(mean_iou)}")
acc = reco_metric.summary()
print(f"Text Recognition - Accuracy: {_pct(acc['raw'])} (unicase: {_pct(acc['unicase'])})")
recall, precision, mean_iou = e2e_metric.summary()
print(
f"OCR - Recall: {_pct(recall['raw'])} (unicase: {_pct(recall['unicase'])}), "
f"Precision: {_pct(precision['raw'])} (unicase: {_pct(precision['unicase'])}), Mean IoU: {_pct(mean_iou)}"
)
def parse_args():
import argparse
parser = argparse.ArgumentParser(
description="DocTR end-to-end evaluation", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("detection", type=str, help="Text detection model to use for analysis")
parser.add_argument("recognition", type=str, help="Text recognition model to use for analysis")
parser.add_argument("--iou", type=float, default=0.5, help="IoU threshold to match a pair of boxes")
parser.add_argument("--dataset", type=str, default="FUNSD", help="choose a dataset: FUNSD, CORD")
parser.add_argument("--img_folder", type=str, default=None, help="Only for local sets, path to images")
parser.add_argument("--label_file", type=str, default=None, help="Only for local sets, path to labels")
parser.add_argument("--rotation", dest="rotation", action="store_true", help="run rotated OCR + postprocessing")
parser.add_argument("-b", "--batch_size", type=int, default=32, help="batch size for recognition")
parser.add_argument("--size", type=int, default=1024, help="model input size, H = W")
parser.add_argument("--keep_ratio", action="store_true", help="keep the aspect ratio of the input image")
parser.add_argument("--symmetric_pad", action="store_true", help="pad the image symmetrically")
parser.add_argument("--samples", type=int, default=None, help="evaluate only on the N first samples")
parser.add_argument(
"--eval-straight",
action="store_true",
help="evaluate on straight pages with straight bbox (to use the quick and light metric)",
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
main(args)