Spaces:
Running
Running
File size: 6,688 Bytes
f3270e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
# Copyright (C) 2021-2025, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import multiprocessing as mp
import os
import time
from pathlib import Path
import torch
from torch.utils.data import DataLoader, SequentialSampler
from torchvision.transforms import Normalize
from doctr.file_utils import CLASS_NAME
if os.getenv("TQDM_SLACK_TOKEN") and os.getenv("TQDM_SLACK_CHANNEL"):
from tqdm.contrib.slack import tqdm
else:
from tqdm.auto import tqdm
from doctr import datasets
from doctr import transforms as T
from doctr.models import detection
from doctr.utils.metrics import LocalizationConfusion
@torch.inference_mode()
def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
    """Run one validation pass of a detection model and summarise its localization metrics.

    Args:
        model: docTR detection model, called as ``model(images, targets, return_preds=True)``; its output
            must contain a ``"loss"`` tensor and a ``"preds"`` list of per-sample dicts of boxes.
        val_loader: iterable yielding ``(images, targets)`` batches.
        batch_transforms: callable applied to each image batch (after the optional move to GPU).
        val_metric: metric object exposing ``reset()``, ``update(gts=..., preds=...)`` and ``summary()``.
        amp: if True, run the forward pass under CUDA autocast (mixed precision).

    Returns:
        tuple: ``(val_loss, recall, precision, mean_iou)``. ``val_loss`` is the loss averaged over batches.
        The other three values come straight from ``val_metric.summary()``.

    Raises:
        ValueError: if ``val_loader`` yields no batch, since the average loss would be undefined.
    """
    # Model in eval mode
    model.eval()
    # Reset val metric
    val_metric.reset()
    # Validation loop
    val_loss, batch_cnt = 0, 0
    for images, targets in tqdm(val_loader):
        if torch.cuda.is_available():
            images = images.cuda()
        images = batch_transforms(images)
        # Wrap each sample's boxes under the single class key the model is called with
        targets = [{CLASS_NAME: t} for t in targets]
        if amp:
            # torch.autocast("cuda") replaces the deprecated torch.cuda.amp.autocast()
            with torch.autocast("cuda"):
                out = model(images, targets, return_preds=True)
        else:
            out = model(images, targets, return_preds=True)
        # Compute metric
        loc_preds = out["preds"]
        for target, loc_pred in zip(targets, loc_preds):
            for boxes_gt, boxes_pred in zip(target.values(), loc_pred.values()):
                # Remove scores (last column of the predictions)
                val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1])
        val_loss += out["loss"].item()
        batch_cnt += 1
    if batch_cnt == 0:
        # Fail with an explicit message instead of a bare ZeroDivisionError
        raise ValueError("The validation loader did not yield any batch")
    val_loss /= batch_cnt
    recall, precision, mean_iou = val_metric.summary()
    return val_loss, recall, precision, mean_iou
def _prefix_with_root_subfolders(dataset):
    """Return the samples of ``dataset`` with each name prefixed by the last two components of its root.

    Args:
        dataset: docTR dataset exposing ``root`` (a path) and ``data`` (``(name, target)`` pairs).

    Returns:
        list: new ``(name, target)`` pairs. ``dataset`` itself is not modified.
    """
    # Path.parts is separator-agnostic and ignores a trailing separator, unlike ``str.split("/")``
    subfolder = Path(dataset.root).parts[-2:]
    return [(os.path.join(*subfolder, name), target) for name, target in dataset.data]


def main(args):
    """Evaluate a docTR text-detection model on the train and test splits of a dataset combined.

    Args:
        args: parsed command-line namespace (as produced by ``parse_args``). ``args.workers`` and
            ``args.device`` may be filled in with defaults.

    Raises:
        AssertionError: if a CUDA device index is requested but PyTorch cannot access a GPU.
        ValueError: if the requested CUDA device index is out of range.
    """
    slack_token = os.getenv("TQDM_SLACK_TOKEN")
    slack_channel = os.getenv("TQDM_SLACK_CHANNEL")
    use_slack = bool(slack_token and slack_channel)

    # pbar is only used for its ``write`` method below; it stays disabled unless Slack reporting is set up
    pbar = tqdm(disable=not use_slack)
    if use_slack:
        # Monkey patch tqdm write method to send messages directly to Slack
        pbar.write = lambda msg: pbar.sio.client.chat_postMessage(channel=slack_channel, text=msg)
    pbar.write(str(args))

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    torch.backends.cudnn.benchmark = True

    # Load docTR model
    model = detection.__dict__[args.arch](
        pretrained=not isinstance(args.resume, str), assume_straight_pages=not args.rotation
    ).eval()

    if isinstance(args.size, int):
        input_shape = (args.size, args.size)
    else:
        input_shape = model.cfg["input_shape"][-2:]
    mean, std = model.cfg["mean"], model.cfg["std"]

    def load_split(train):
        # Each split gets its own Resize transform, as in two independent dataset constructions
        return datasets.__dict__[args.dataset](
            train=train,
            download=True,
            use_polygons=args.rotation,
            detection_task=True,
            sample_transforms=T.Resize(
                input_shape, preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad
            ),
        )

    st = time.time()
    ds = load_split(train=True)
    # Monkeypatch: prefix sample names with the split's subfolders (taken from the current root), then move
    # the root two levels up so the test split's samples can be appended to the same dataset.
    # Presumably the dataset loads images from os.path.join(root, name) -- TODO confirm
    ds.data = _prefix_with_root_subfolders(ds)
    ds.root = str(Path(ds.root).parent.parent)
    ds.data.extend(_prefix_with_root_subfolders(load_split(train=False)))

    test_loader = DataLoader(
        ds,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.workers,
        sampler=SequentialSampler(ds),
        pin_memory=torch.cuda.is_available(),
        collate_fn=ds.collate_fn,
    )
    pbar.write(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in {len(test_loader)} batches)")

    batch_transforms = Normalize(mean=mean, std=std)

    # Resume weights
    if isinstance(args.resume, str):
        pbar.write(f"Resuming {args.resume}")
        model.from_pretrained(args.resume)

    # GPU
    if isinstance(args.device, int):
        if not torch.cuda.is_available():
            raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
        if args.device >= torch.cuda.device_count():
            raise ValueError("Invalid device index")
    # Silent default switch to GPU if available
    elif torch.cuda.is_available():
        args.device = 0
    else:
        pbar.write("No accessible GPU, target device set to CPU.")
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)
        model = model.cuda()

    # Metrics
    metric = LocalizationConfusion(use_polygons=args.rotation)

    pbar.write("Running evaluation")
    val_loss, recall, precision, mean_iou = evaluate(model, test_loader, batch_transforms, metric, amp=args.amp)
    pbar.write(
        f"Validation loss: {val_loss:.6} (Recall: {recall:.2%} | Precision: {precision:.2%} | Mean IoU: {mean_iou:.2%})"
    )
def parse_args():
    """Build the command-line interface of the evaluation script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed evaluation options.
    """
    import argparse

    cli = argparse.ArgumentParser(
        description="docTR evaluation script for text detection (PyTorch)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # (flags, keyword arguments) registered in order, so the help output keeps the same layout
    option_specs = (
        (("arch",), dict(type=str, help="text-detection model to evaluate")),
        (("--dataset",), dict(type=str, default="FUNSD", help="Dataset to evaluate on")),
        (("-b", "--batch_size"), dict(type=int, default=2, help="batch size for evaluation")),
        (("--device",), dict(default=None, type=int, help="device")),
        (("--size",), dict(type=int, default=None, help="model input size, H = W")),
        (("--keep_ratio",), dict(action="store_true", help="keep the aspect ratio of the input image")),
        (("--symmetric_pad",), dict(action="store_true", help="pad the image symmetrically")),
        (("-j", "--workers"), dict(type=int, default=None, help="number of workers used for dataloading")),
        (("--rotation",), dict(dest="rotation", action="store_true", help="inference with rotated bbox")),
        (("--resume",), dict(type=str, default=None, help="Checkpoint to resume")),
        (("--amp",), dict(dest="amp", help="Use Automatic Mixed Precision", action="store_true")),
    )
    for flags, spec in option_specs:
        cli.add_argument(*flags, **spec)

    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: parse the command line and launch the evaluation
    main(parse_args())
|