Spaces:
Running
Running
| # Copyright (C) 2021-2025, Mindee. | |
| # This program is licensed under the Apache License 2.0. | |
| # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details. | |
| import multiprocessing as mp | |
| import os | |
| import time | |
| import torch | |
| from torch.utils.data import DataLoader, SequentialSampler | |
| from torchvision.transforms import Normalize | |
| if os.getenv("TQDM_SLACK_TOKEN") and os.getenv("TQDM_SLACK_CHANNEL"): | |
| from tqdm.contrib.slack import tqdm | |
| else: | |
| from tqdm.auto import tqdm | |
| from doctr import datasets | |
| from doctr import transforms as T | |
| from doctr.datasets import VOCABS | |
| from doctr.models import recognition | |
| from doctr.utils.metrics import TextMatch | |
def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
    """Run a recognition model over a loader and compute its mean loss and text-match scores.

    Args:
        model: recognition model, called as ``model(images, targets, return_preds=True)``.
            It must return a dict with a ``"loss"`` tensor and ``"preds"``, a sequence of
            pairs whose first element is the predicted word.
        val_loader: iterable yielding ``(images, targets)`` batches.
        batch_transforms: callable applied to each image batch, after the optional move to GPU.
        val_metric: metric exposing ``reset()``, ``update(targets, words)`` and ``summary()``.
            The summary must contain the keys ``"raw"`` and ``"unicase"``.
        amp: if True, run the forward pass under CUDA automatic mixed precision.

    Returns:
        tuple: ``(mean_loss, raw_score, unicase_score)``. ``mean_loss`` is averaged over the
        batches that were actually evaluated, and is ``nan`` when no batch could be evaluated.
    """
    # Model in eval mode
    model.eval()
    # Reset val metric
    val_metric.reset()
    val_loss, batch_cnt = 0.0, 0
    pbar = tqdm(val_loader)
    # Evaluation never backpropagates, so skip building the autograd graph to save memory/time.
    with torch.no_grad():
        for images, targets in pbar:
            if torch.cuda.is_available():
                images = images.cuda()
            images = batch_transforms(images)
            # Keep the try narrow: only the forward pass consumes `targets`, which is what the
            # skip message below refers to (a ValueError here is taken to mean the targets hold
            # symbols the model cannot handle).
            try:
                if amp:
                    with torch.autocast("cuda"):
                        out = model(images, targets, return_preds=True)
                else:
                    out = model(images, targets, return_preds=True)
            except ValueError:
                pbar.write(f"unexpected symbol/s in targets:\n{targets} \n--> skip batch")
                continue
            # Compute metric on the predicted words only
            words = [word for word, _ in out["preds"]]
            val_metric.update(targets, words)
            val_loss += out["loss"].item()
            batch_cnt += 1
    # Guard against an empty loader / every batch skipped (previously a ZeroDivisionError)
    val_loss = val_loss / batch_cnt if batch_cnt > 0 else float("nan")
    result = val_metric.summary()
    return val_loss, result["raw"], result["unicase"]
def _make_pbar():
    """Create the logging progress bar; when Slack env vars are set, route ``write`` to Slack."""
    slack_token = os.getenv("TQDM_SLACK_TOKEN")
    slack_channel = os.getenv("TQDM_SLACK_CHANNEL")
    use_slack = bool(slack_token and slack_channel)
    # Without Slack credentials the bar is disabled; main() only uses it for `write` logging.
    pbar = tqdm(disable=not use_slack)
    if use_slack:
        # Monkey patch tqdm write method to send messages directly to Slack
        pbar.write = lambda msg: pbar.sio.client.chat_postMessage(channel=slack_channel, text=msg)
    return pbar


def _load_split(args, train):
    """Instantiate one split of ``args.dataset`` in recognition mode, resized to the model input."""
    return datasets.__dict__[args.dataset](
        train=train,
        download=True,
        recognition_task=True,
        use_polygons=args.regular,
        img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
    )


def _select_device(args, pbar):
    """Validate ``args.device``, or default it to GPU 0 when CUDA is available (mutates ``args``).

    Raises:
        AssertionError: a device index was requested but CUDA is unavailable.
        ValueError: the requested device index is negative or not below the CUDA device count.
    """
    if isinstance(args.device, int):
        if not torch.cuda.is_available():
            raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
        # Reject negative indices as well as ones past the last visible device
        if not 0 <= args.device < torch.cuda.device_count():
            raise ValueError("Invalid device index")
    elif torch.cuda.is_available():
        # Silent default switch to GPU if available
        args.device = 0
    else:
        pbar.write("No accessible GPU, target device set to CPU.")


def main(args):
    """Evaluate a docTR recognition model on the merged train and test splits of a dataset.

    Args:
        args: namespace produced by ``parse_args()``. ``args.workers`` and ``args.device`` may
            be filled in here when they were left unset.
    """
    pbar = _make_pbar()
    pbar.write(str(args))

    torch.backends.cudnn.benchmark = True

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    # Load doctr model: hub-pretrained weights only when no checkpoint is given to resume from
    model = recognition.__dict__[args.arch](
        pretrained=args.resume is None,
        input_shape=(3, args.input_size, 4 * args.input_size),
        vocab=VOCABS[args.vocab],
    ).eval()

    # Resume weights
    if isinstance(args.resume, str):
        pbar.write(f"Resuming {args.resume}")
        model.from_pretrained(args.resume)

    st = time.time()
    # Both splits are merged so the evaluation covers the whole dataset
    ds = _load_split(args, train=True)
    ds.data.extend(_load_split(args, train=False).data)

    test_loader = DataLoader(
        ds,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.workers,
        sampler=SequentialSampler(ds),
        pin_memory=torch.cuda.is_available(),
        collate_fn=ds.collate_fn,
    )
    pbar.write(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in {len(test_loader)} batches)")

    mean, std = model.cfg["mean"], model.cfg["std"]
    batch_transforms = Normalize(mean=mean, std=std)

    # Metrics
    val_metric = TextMatch()

    # GPU
    _select_device(args, pbar)
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)
        model = model.cuda()

    pbar.write("Running evaluation")
    val_loss, exact_match, partial_match = evaluate(model, test_loader, batch_transforms, val_metric, amp=args.amp)
    pbar.write(f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})")
def parse_args(argv=None):
    """Parse the command-line options of the recognition evaluation script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which matches the original behaviour.

    Returns:
        argparse.Namespace: the parsed options.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="docTR evaluation script for text recognition (PyTorch)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument("arch", type=str, help="text-recognition model to evaluate")
    parser.add_argument("--vocab", type=str, default="french", help="Vocab to be used for evaluation")
    parser.add_argument("--dataset", type=str, default="FUNSD", help="Dataset to evaluate on")
    parser.add_argument(
        "--device",
        default=None,
        type=int,
        help="CUDA device index to run on (when omitted: GPU 0 if CUDA is available, else CPU)",
    )
    parser.add_argument("-b", "--batch_size", type=int, default=1, help="batch size for evaluation")
    parser.add_argument("--input_size", type=int, default=32, help="input size H for the model, W = 4*H")
    parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading")
    parser.add_argument(
        "--only_regular", dest="regular", action="store_true", help="test set contains only regular text"
    )
    parser.add_argument("--resume", type=str, default=None, help="Checkpoint to resume")
    parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse the CLI options, then run the evaluation with them.
    args = parse_args()
    main(args)