File size: 3,195 Bytes
f3270e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Copyright (C) 2021-2025, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import argparse
import json
import os
from pathlib import Path

from tqdm import tqdm

from doctr.io import DocumentFile
from doctr.models import detection, ocr_predictor

# Suffixes (with leading dot) loaded through DocumentFile.from_images.
IMAGE_FILE_EXTENSIONS = [
    ".jpeg",
    ".jpg",
    ".png",
    ".tif",
    ".tiff",
    ".bmp",
]
# Suffixes (with leading dot) loaded through DocumentFile.from_pdf.
OTHER_EXTENSIONS = [
    ".pdf",
]


def _process_file(model, file_path: Path, out_format: str) -> None:
    """Run the OCR predictor on one image or PDF file and write the result under ``output/``.

    Args:
        model: OCR predictor; called with the loaded document, and its result must
            provide ``export()``, ``render()`` and ``export_as_xml()``.
        file_path: file to analyze. Its suffix, compared case-insensitively,
            selects the image or PDF loader. Files with any other suffix are
            reported and skipped.
        out_format: one of ``"txt"``, ``"json"`` or ``"xml"``.

    Raises:
        ValueError: if ``out_format`` is not one of the supported formats.
    """
    if out_format not in ["txt", "json", "xml"]:
        raise ValueError(f"Unsupported output format: {out_format}")

    file_path = Path(file_path)
    # Lower-case the suffix so "scan.JPG" is treated like "scan.jpg"; main()'s
    # directory filter already matches suffixes case-insensitively.
    suffix = file_path.suffix.lower()
    if suffix in IMAGE_FILE_EXTENSIONS:
        doc = DocumentFile.from_images([file_path])
    elif suffix in OTHER_EXTENSIONS:
        doc = DocumentFile.from_pdf(file_path)
    else:
        print(f"Skip unsupported file type: {file_path}")
        # Nothing was loaded, so there is nothing to analyze or write.
        return

    out = model(doc)
    out_dir = Path("output")

    if out_format == "xml":
        # export_as_xml() yields (bytes, tree) pairs (presumably one per page);
        # each tree is written to its own numbered file.
        for i, (_, xml_tree) in enumerate(out.export_as_xml()):
            xml_tree.write(out_dir / f"{file_path.stem}_{i}.xml", encoding="utf-8", xml_declaration=True)
        return

    if out_format == "json":
        output = json.dumps(out.export(), indent=2)
    else:
        output = out.render()

    # Explicit UTF-8: recognized text may contain characters the platform's
    # default encoding cannot represent.
    with open(out_dir / f"{file_path.stem}.{out_format}", "w", encoding="utf-8") as f:
        f.write(output)


def main(args):
    """Build the OCR predictor from parsed CLI options and analyze ``args.path``.

    ``args.path`` may be a single file or a directory. For a directory, every
    regular file directly inside it whose name ends (case-insensitively) with a
    supported image or PDF suffix is processed, in sorted order. Results are
    written to the ``output`` directory, which is created if missing.

    Args:
        args: namespace produced by ``parse_args()``.
    """
    detection_model = detection.__dict__[args.detection](
        pretrained=True,
        bin_thresh=args.bin_thresh,
        box_thresh=args.box_thresh,
    )
    model = ocr_predictor(detection_model, args.recognition, pretrained=True)
    path = Path(args.path)

    os.makedirs(name="output", exist_ok=True)

    if path.is_dir():
        # Build the suffix tuple once instead of once per directory entry.
        supported = tuple(IMAGE_FILE_EXTENSIONS + OTHER_EXTENSIONS)
        # is_file() skips sub-directories whose names happen to end in a supported
        # suffix; sorting makes the processing order deterministic.
        to_process = sorted(f for f in path.iterdir() if f.is_file() and f.name.lower().endswith(supported))
        for file_path in tqdm(to_process):
            _process_file(model, file_path, args.format)
    else:
        _process_file(model, path, args.format)


def parse_args():
    """Read the script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``path``, ``detection``,
        ``bin_thresh``, ``box_thresh``, ``recognition`` and ``format``.
    """
    cli = argparse.ArgumentParser(
        description="DocTR text detection",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Positional input: a PDF, an image, or a directory of them.
    cli.add_argument("path", type=str, help="Path to process: PDF, image, directory")

    # Model choices and detection thresholds.
    cli.add_argument("--detection", type=str, default="fast_base", help="Text detection model to use for analysis")
    cli.add_argument("--bin-thresh", type=float, default=0.3, help="Binarization threshold for the detection model.")
    cli.add_argument("--box-thresh", type=float, default=0.1, help="Threshold for the detection boxes.")
    cli.add_argument("--recognition", type=str, default="crnn_vgg16_bn", help="Text recognition model to use for analysis")

    # Output serialization format.
    cli.add_argument("-f", "--format", choices=["txt", "json", "xml"], default="txt", help="Output format")

    namespace = cli.parse_args()
    return namespace


if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the analysis with them.
    main(parse_args())