# Uploaded by MikeTrizna via huggingface_hub ("Upload folder using huggingface_hub", commit f3270e6).
# Copyright (C) 2021-2025, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import multiprocessing as mp
import os
import time
from pathlib import Path
import torch
from torch.utils.data import DataLoader, SequentialSampler
from torchvision.transforms import Normalize
from doctr.file_utils import CLASS_NAME
# Pick the progress-bar implementation: the Slack-reporting variant is used only when both the
# token and the channel environment variables are set; otherwise fall back to tqdm.auto.
if not (os.getenv("TQDM_SLACK_TOKEN") and os.getenv("TQDM_SLACK_CHANNEL")):
    from tqdm.auto import tqdm
else:
    from tqdm.contrib.slack import tqdm
from doctr import datasets
from doctr import transforms as T
from doctr.models import detection
from doctr.utils.metrics import LocalizationConfusion
@torch.inference_mode()
def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
    """Run one evaluation pass of a detection model over ``val_loader``.

    Args:
        model: detection model called as ``model(images, targets, return_preds=True)``; its output
            is read as a mapping with ``"loss"`` (a tensor) and ``"preds"`` (one dict per sample).
        val_loader: iterable yielding ``(images, targets)`` batches.
        batch_transforms: callable applied to each image batch (after the optional move to GPU).
        val_metric: metric object exposing ``reset()``, ``update(gts=..., preds=...)`` and
            ``summary()`` (the latter returning recall, precision and mean IoU).
        amp: if True, run the forward pass under CUDA automatic mixed precision.

    Returns:
        ``(val_loss, recall, precision, mean_iou)``, where ``val_loss`` is the mean of the
        per-batch losses and the other three values come from ``val_metric.summary()``.

    Raises:
        ValueError: if ``val_loader`` yields no batches (the mean loss would be undefined).
    """
    # Model in eval mode
    model.eval()
    # Reset the metric so that repeated calls do not accumulate previous results
    val_metric.reset()

    val_loss, batch_cnt = 0.0, 0
    for images, targets in tqdm(val_loader):
        if torch.cuda.is_available():
            images = images.cuda()
        images = batch_transforms(images)
        # Wrap each sample's target under the class key the model expects
        targets = [{CLASS_NAME: t} for t in targets]
        if amp:
            # torch.cuda.amp.autocast() is deprecated; torch.autocast is the current API
            with torch.autocast(device_type="cuda"):
                out = model(images, targets, return_preds=True)
        else:
            out = model(images, targets, return_preds=True)

        # Compute metric
        for target, loc_pred in zip(targets, out["preds"]):
            for boxes_gt, boxes_pred in zip(target.values(), loc_pred.values()):
                # Drop the trailing score column before comparing against ground truth
                val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1])

        val_loss += out["loss"].item()
        batch_cnt += 1

    if batch_cnt == 0:
        raise ValueError("val_loader yielded no batches; cannot compute the mean validation loss")
    val_loss /= batch_cnt
    recall, precision, mean_iou = val_metric.summary()
    return val_loss, recall, precision, mean_iou
def _load_split(args, input_shape, train):
    """Instantiate one split of ``args.dataset`` configured for the detection task.

    Args:
        args: parsed CLI namespace (reads ``dataset``, ``rotation``, ``keep_ratio``, ``symmetric_pad``).
        input_shape: target size passed to the resize transform.
        train: passed through as the dataset's ``train`` flag to select the split.
    """
    return datasets.__dict__[args.dataset](
        train=train,
        download=True,
        use_polygons=args.rotation,
        detection_task=True,
        sample_transforms=T.Resize(
            input_shape, preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad
        ),
    )


def _samples_relative_to_grandparent(ds):
    """Return ``ds.data`` with every sample name prefixed by the last two components of ``ds.root``.

    ``Path.parts`` is used rather than splitting on ``"/"`` so the prefix is also correct with
    Windows separators and when ``ds.root`` ends with a separator (matching ``Path.parent``).
    """
    subfolder = Path(ds.root).parts[-2:]
    return [(os.path.join(*subfolder, name), target) for name, target in ds.data]


def main(args):
    """Evaluate a docTR text-detection model on the union of a dataset's train and test splits.

    Reports the validation loss, recall, precision and mean IoU through ``pbar.write``.

    Args:
        args: namespace produced by ``parse_args``. ``args.workers`` (and ``args.device`` when
            CUDA is available) are filled in if they were left unset.

    Raises:
        AssertionError: if a device index is given but CUDA is not available.
        ValueError: if the given device index is out of range.
    """
    slack_token = os.getenv("TQDM_SLACK_TOKEN")
    slack_channel = os.getenv("TQDM_SLACK_CHANNEL")
    use_slack = bool(slack_token and slack_channel)

    # The bar is only enabled when reporting to Slack; below it is used solely for ``write``
    pbar = tqdm(disable=not use_slack)
    if use_slack:
        # Monkey patch tqdm write method to send messages directly to Slack
        pbar.write = lambda msg: pbar.sio.client.chat_postMessage(channel=slack_channel, text=msg)
    pbar.write(str(args))

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    torch.backends.cudnn.benchmark = True

    # Load docTR model; pretrained weights are only requested when no checkpoint is resumed below
    model = detection.__dict__[args.arch](
        pretrained=not isinstance(args.resume, str), assume_straight_pages=not args.rotation
    ).eval()

    if isinstance(args.size, int):
        input_shape = (args.size, args.size)
    else:
        input_shape = model.cfg["input_shape"][-2:]
    mean, std = model.cfg["mean"], model.cfg["std"]

    st = time.time()
    ds = _load_split(args, input_shape, train=True)
    # Monkeypatch: move the root two levels up and prefix each sample name with the two dropped
    # path components, so the test split's samples can be appended to the same dataset
    # (presumably the dataset resolves images as root / name).
    # NOTE(review): assumes both splits share the same grandparent directory; confirm for
    # datasets other than the default FUNSD.
    train_samples = _samples_relative_to_grandparent(ds)  # must run before ds.root is changed
    ds.root = str(Path(ds.root).parent.parent)
    ds.data = train_samples
    _ds = _load_split(args, input_shape, train=False)
    ds.data.extend(_samples_relative_to_grandparent(_ds))

    test_loader = DataLoader(
        ds,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.workers,
        sampler=SequentialSampler(ds),
        pin_memory=torch.cuda.is_available(),
        collate_fn=ds.collate_fn,
    )
    pbar.write(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in {len(test_loader)} batches)")

    batch_transforms = Normalize(mean=mean, std=std)

    # Resume weights
    if isinstance(args.resume, str):
        pbar.write(f"Resuming {args.resume}")
        model.from_pretrained(args.resume)

    # GPU
    if isinstance(args.device, int):
        if not torch.cuda.is_available():
            raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
        if args.device >= torch.cuda.device_count():
            raise ValueError("Invalid device index")
    # Silent default switch to GPU if available
    elif torch.cuda.is_available():
        args.device = 0
    else:
        pbar.write("No accessible GPU, target device set to CPU.")
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)
        model = model.cuda()

    # Metrics
    metric = LocalizationConfusion(use_polygons=args.rotation)

    pbar.write("Running evaluation")
    val_loss, recall, precision, mean_iou = evaluate(model, test_loader, batch_transforms, metric, amp=args.amp)
    pbar.write(
        f"Validation loss: {val_loss:.6} (Recall: {recall:.2%} | Precision: {precision:.2%} | Mean IoU: {mean_iou:.2%})"
    )
def parse_args():
    """Build the command-line interface of the evaluation script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the fields ``arch``, ``dataset``, ``batch_size``, ``device``,
        ``size``, ``keep_ratio``, ``symmetric_pad``, ``workers``, ``rotation``, ``resume``, ``amp``.
    """
    import argparse

    cli = argparse.ArgumentParser(
        description="docTR evaluation script for text detection (PyTorch)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Positional: name of the detection architecture
    cli.add_argument("arch", type=str, help="text-detection model to evaluate")
    # Data / batching
    cli.add_argument("--dataset", type=str, default="FUNSD", help="Dataset to evaluate on")
    cli.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for evaluation")
    # Hardware
    cli.add_argument("--device", default=None, type=int, help="device")
    # Input preprocessing
    cli.add_argument("--size", type=int, default=None, help="model input size, H = W")
    cli.add_argument("--keep_ratio", action="store_true", help="keep the aspect ratio of the input image")
    cli.add_argument("--symmetric_pad", action="store_true", help="pad the image symmetrically")
    cli.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading")
    # Model behaviour
    cli.add_argument("--rotation", action="store_true", help="inference with rotated bbox")
    cli.add_argument("--resume", type=str, default=None, help="Checkpoint to resume")
    cli.add_argument("--amp", action="store_true", help="Use Automatic Mixed Precision")
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: parse the CLI and run the evaluation
    main(parse_args())