# Copyright (C) 2021-2025, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import multiprocessing as mp
import os
import time

import torch
from torch.utils.data import DataLoader, SequentialSampler
from torchvision.transforms import Normalize

# Route progress/log output to Slack when both env vars are configured,
# otherwise fall back to the regular console progress bar.
if os.getenv("TQDM_SLACK_TOKEN") and os.getenv("TQDM_SLACK_CHANNEL"):
    from tqdm.contrib.slack import tqdm
else:
    from tqdm.auto import tqdm

from doctr import datasets
from doctr import transforms as T
from doctr.datasets import VOCABS
from doctr.models import recognition
from doctr.utils.metrics import TextMatch


@torch.inference_mode()
def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
    """Run a full evaluation pass of a recognition model over a dataloader.

    Args:
        model: recognition model (moved to GPU by the caller when available)
        val_loader: dataloader yielding (images, targets) batches
        batch_transforms: normalization transform applied to each image batch
        val_metric: TextMatch-like metric exposing reset/update/summary
        amp: if True, run the forward pass under CUDA automatic mixed precision

    Returns:
        tuple: (mean validation loss, exact-match score, unicase-match score)

    Raises:
        RuntimeError: if every batch was skipped (no valid batch to average over)
    """
    model.eval()
    val_metric.reset()

    val_loss, batch_cnt = 0.0, 0
    pbar = tqdm(val_loader)
    for images, targets in pbar:
        try:
            if torch.cuda.is_available():
                images = images.cuda()
            images = batch_transforms(images)
            if amp:
                # torch.cuda.amp.autocast is deprecated; device-qualified form is equivalent
                with torch.amp.autocast("cuda"):
                    out = model(images, targets, return_preds=True)
            else:
                out = model(images, targets, return_preds=True)
            # preds is a list of (word, confidence) pairs; keep only the words
            if len(out["preds"]):
                words, _ = zip(*out["preds"])
            else:
                words = []
            val_metric.update(targets, words)

            val_loss += out["loss"].item()
            batch_cnt += 1
        except ValueError:
            # Targets containing symbols outside the model vocab raise ValueError
            pbar.write(f"unexpected symbol/s in targets:\n{targets} \n--> skip batch")
            continue

    # Guard against a division by zero when every batch was skipped
    if batch_cnt == 0:
        raise RuntimeError("All batches were skipped: no valid targets for the selected vocab")
    val_loss /= batch_cnt
    result = val_metric.summary()
    return val_loss, result["raw"], result["unicase"]


def main(args):
    """Load the model and dataset described by ``args`` and report evaluation metrics."""
    slack_token = os.getenv("TQDM_SLACK_TOKEN")
    slack_channel = os.getenv("TQDM_SLACK_CHANNEL")

    # The bar itself is only shown when reporting to Slack; it is otherwise
    # used purely as a `write` sink for log messages.
    pbar = tqdm(disable=not (slack_token and slack_channel))
    if slack_token and slack_channel:
        # Monkey patch tqdm write method to send messages directly to Slack
        pbar.write = lambda msg: pbar.sio.client.chat_postMessage(channel=slack_channel, text=msg)
    pbar.write(str(args))

    torch.backends.cudnn.benchmark = True

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    # Load doctr model: pretrained weights only when not resuming from a checkpoint
    model = recognition.__dict__[args.arch](
        pretrained=args.resume is None,
        input_shape=(3, args.input_size, 4 * args.input_size),
        vocab=VOCABS[args.vocab],
    ).eval()

    # Resume weights
    if isinstance(args.resume, str):
        pbar.write(f"Resuming {args.resume}")
        model.from_pretrained(args.resume)

    st = time.time()
    # Evaluate on the union of the train and test splits of the dataset
    ds = datasets.__dict__[args.dataset](
        train=True,
        download=True,
        recognition_task=True,
        use_polygons=args.regular,
        img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
    )

    _ds = datasets.__dict__[args.dataset](
        train=False,
        download=True,
        recognition_task=True,
        use_polygons=args.regular,
        img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
    )
    # Samples are already (np_img, target) pairs; extend directly
    ds.data.extend(_ds.data)

    test_loader = DataLoader(
        ds,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.workers,
        sampler=SequentialSampler(ds),
        pin_memory=torch.cuda.is_available(),
        collate_fn=ds.collate_fn,
    )
    pbar.write(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in {len(test_loader)} batches)")

    mean, std = model.cfg["mean"], model.cfg["std"]
    batch_transforms = Normalize(mean=mean, std=std)

    # Metrics
    val_metric = TextMatch()

    # GPU
    if isinstance(args.device, int):
        if not torch.cuda.is_available():
            raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
        if args.device >= torch.cuda.device_count():
            raise ValueError("Invalid device index")
    # Silent default switch to GPU if available
    elif torch.cuda.is_available():
        args.device = 0
    else:
        pbar.write("No accessible GPU, target device set to CPU.")
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)
        model = model.cuda()

    pbar.write("Running evaluation")
    val_loss, exact_match, partial_match = evaluate(model, test_loader, batch_transforms, val_metric, amp=args.amp)
    pbar.write(f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})")


def parse_args():
    """Parse command-line arguments for the recognition evaluation script."""
    import argparse

    parser = argparse.ArgumentParser(
        description="docTR evaluation script for text recognition (PyTorch)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument("arch", type=str, help="text-recognition model to evaluate")
    parser.add_argument("--vocab", type=str, default="french", help="Vocab to be used for evaluation")
    parser.add_argument("--dataset", type=str, default="FUNSD", help="Dataset to evaluate on")
    parser.add_argument("--device", default=None, type=int, help="device")
    parser.add_argument("-b", "--batch_size", type=int, default=1, help="batch size for evaluation")
    parser.add_argument("--input_size", type=int, default=32, help="input size H for the model, W = 4*H")
    parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading")
    parser.add_argument(
        "--only_regular", dest="regular", action="store_true", help="test set contains only regular text"
    )
    parser.add_argument("--resume", type=str, default=None, help="Checkpoint to resume")
    parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
    args = parser.parse_args()

    return args


if __name__ == "__main__":
    args = parse_args()
    main(args)