MikeTrizna's picture
Upload folder using huggingface_hub
f3270e6 verified
# Copyright (C) 2021-2025, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import multiprocessing as mp
import os
import time
import torch
from torch.utils.data import DataLoader, SequentialSampler
from torchvision.transforms import Normalize
if os.getenv("TQDM_SLACK_TOKEN") and os.getenv("TQDM_SLACK_CHANNEL"):
from tqdm.contrib.slack import tqdm
else:
from tqdm.auto import tqdm
from doctr import datasets
from doctr import transforms as T
from doctr.datasets import VOCABS
from doctr.models import recognition
from doctr.utils.metrics import TextMatch
@torch.inference_mode()
def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
    """Run the model over a loader and aggregate the loss and text-match scores.

    Args:
        model: recognition model, called as ``model(images, targets, return_preds=True)``;
            the result is indexed with ``"loss"`` (a tensor) and ``"preds"`` (a sequence of pairs
            whose first element is the predicted word).
        val_loader: iterable yielding ``(images, targets)`` batches.
        batch_transforms: callable applied to the image batch before the forward pass.
        val_metric: metric object exposing ``reset()``, ``update(targets, words)`` and ``summary()``;
            the summary is indexed with ``"raw"`` and ``"unicase"``.
        amp: if True, the forward pass runs inside ``torch.cuda.amp.autocast()``.

    Returns:
        A tuple ``(mean_loss, raw_score, unicase_score)`` where ``mean_loss`` is averaged over the
        batches that were actually evaluated (skipped batches are not counted).

    Raises:
        ZeroDivisionError: if no batch could be evaluated (empty loader, or every batch skipped).
    """
    # Model in eval mode
    model.eval()
    # Reset val metric
    val_metric.reset()
    # Validation loop
    val_loss, batch_cnt = 0, 0
    pbar = tqdm(val_loader)
    for images, targets in pbar:
        try:
            # Inputs follow CUDA availability, not an explicit device argument
            if torch.cuda.is_available():
                images = images.cuda()
            images = batch_transforms(images)
            if amp:
                with torch.cuda.amp.autocast():
                    out = model(images, targets, return_preds=True)
            else:
                out = model(images, targets, return_preds=True)
            # Compute metric: keep only the first element of each prediction pair
            if len(out["preds"]):
                words, _ = zip(*out["preds"])
            else:
                words = []
            val_metric.update(targets, words)

            # Only batches that went through the metric update contribute to the average
            val_loss += out["loss"].item()
            batch_cnt += 1
        except ValueError:
            # Best-effort evaluation: report the offending targets and move on to the next batch
            pbar.write(f"unexpected symbol/s in targets:\n{targets} \n--> skip batch")
            continue

    # Without this guard the division below fails with an uninformative "division by zero";
    # the exception type is kept so existing callers see the same class of error.
    if batch_cnt == 0:
        raise ZeroDivisionError(
            "no batch could be evaluated: the loader is empty or every batch was skipped"
        )

    val_loss /= batch_cnt
    result = val_metric.summary()
    return val_loss, result["raw"], result["unicase"]
def main(args):
    """Evaluate a docTR text recognition model on a dataset and report loss and match rates.

    The train and test splits of ``args.dataset`` are both loaded and concatenated into a single
    evaluation set. Progress messages go through ``tqdm.write``, or to Slack when both
    ``TQDM_SLACK_TOKEN`` and ``TQDM_SLACK_CHANNEL`` are set in the environment.

    Args:
        args: namespace providing ``arch``, ``vocab``, ``dataset``, ``device``, ``batch_size``,
            ``input_size``, ``workers``, ``regular``, ``resume`` and ``amp``. ``args.workers`` and
            ``args.device`` are filled in with defaults when they are not integers.

    Raises:
        AssertionError: if a device index was requested but CUDA is not available.
        ValueError: if the requested device index is not below ``torch.cuda.device_count()``.
    """
    token = os.getenv("TQDM_SLACK_TOKEN")
    channel = os.getenv("TQDM_SLACK_CHANNEL")
    slack_enabled = bool(token and channel)

    # The bar itself is only displayed when Slack reporting is configured
    pbar = tqdm(disable=not slack_enabled)
    if slack_enabled:
        # Monkey patch tqdm write method to send messages directly to Slack
        pbar.write = lambda msg: pbar.sio.client.chat_postMessage(channel=channel, text=msg)
    pbar.write(str(args))

    torch.backends.cudnn.benchmark = True

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    # Images are resized to H x 4H, and the model is built for the same shape
    target_size = (args.input_size, 4 * args.input_size)

    # Load doctr model (pretrained weights only when no checkpoint is given)
    model = recognition.__dict__[args.arch](
        pretrained=args.resume is None,
        input_shape=(3, *target_size),
        vocab=VOCABS[args.vocab],
    ).eval()

    # Resume weights
    if isinstance(args.resume, str):
        pbar.write(f"Resuming {args.resume}")
        model.from_pretrained(args.resume)

    st = time.time()
    dataset_cls = datasets.__dict__[args.dataset]

    def _load_split(train):
        # Both splits share every option except the ``train`` flag
        return dataset_cls(
            train=train,
            download=True,
            recognition_task=True,
            use_polygons=args.regular,
            img_transforms=T.Resize(target_size, preserve_aspect_ratio=True),
        )

    ds = _load_split(True)
    _ds = _load_split(False)
    # Append the test split samples to the train split so that everything is evaluated at once
    ds.data.extend((sample, label) for sample, label in _ds.data)

    has_cuda = torch.cuda.is_available()

    test_loader = DataLoader(
        ds,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.workers,
        sampler=SequentialSampler(ds),
        pin_memory=has_cuda,
        collate_fn=ds.collate_fn,
    )
    pbar.write(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in {len(test_loader)} batches)")

    # Normalization statistics come from the model configuration
    batch_transforms = Normalize(mean=model.cfg["mean"], std=model.cfg["std"])

    # Metrics
    val_metric = TextMatch()

    # GPU
    if isinstance(args.device, int):
        # An explicit index must refer to an accessible device
        if not has_cuda:
            raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
        if args.device >= torch.cuda.device_count():
            raise ValueError("Invalid device index")
    elif has_cuda:
        # Silent default switch to GPU if available
        args.device = 0
    else:
        pbar.write("No accessible GPU, target device set to CPU.")

    if has_cuda:
        torch.cuda.set_device(args.device)
        model = model.cuda()

    pbar.write("Running evaluation")
    val_loss, exact_match, partial_match = evaluate(model, test_loader, batch_transforms, val_metric, amp=args.amp)
    pbar.write(f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})")
def parse_args(argv=None):
    """Build the command-line interface of the evaluation script and parse the arguments.

    Args:
        argv: optional list of argument strings to parse. When None (the default), argparse
            reads the arguments from ``sys.argv``, which is the behaviour of a plain
            ``parse_args()`` call.

    Returns:
        An ``argparse.Namespace`` with the attributes ``arch``, ``vocab``, ``dataset``, ``device``,
        ``batch_size``, ``input_size``, ``workers``, ``regular``, ``resume`` and ``amp``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="docTR evaluation script for text recognition (PyTorch)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Positional: name of the architecture to evaluate
    parser.add_argument("arch", type=str, help="text-recognition model to evaluate")

    # Data and model options
    parser.add_argument("--vocab", type=str, default="french", help="Vocab to be used for evaluation")
    parser.add_argument("--dataset", type=str, default="FUNSD", help="Dataset to evaluate on")
    parser.add_argument("--device", default=None, type=int, help="device")
    parser.add_argument("-b", "--batch_size", type=int, default=1, help="batch size for evaluation")
    parser.add_argument("--input_size", type=int, default=32, help="input size H for the model, W = 4*H")
    parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading")
    # The flag is spelled --only_regular but is stored under the ``regular`` attribute
    parser.add_argument(
        "--only_regular", dest="regular", action="store_true", help="test set contains only regular text"
    )
    parser.add_argument("--resume", type=str, default=None, help="Checkpoint to resume")
    parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")

    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse the command line and run the evaluation
    main(parse_args())