File size: 8,110 Bytes
d7c9ee5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""PP-DocLayout-V3 layout model for the docling standard pipeline.

Runs PaddlePaddle PP-DocLayout-V3 locally via HuggingFace ``transformers``
to detect document layout elements and returns ``LayoutPrediction`` objects
that docling merges with its standard-pipeline output.
"""

from __future__ import annotations

import logging
import warnings
from typing import TYPE_CHECKING

import numpy as np
import torch
from docling.datamodel.base_models import BoundingBox, Cluster, LayoutPrediction, Page
from docling.models.base_layout_model import BaseLayoutModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.layout_postprocessor import LayoutPostprocessor
from docling.utils.profiling import TimeRecorder
from docling_core.types.doc import DocItemLabel
from transformers import AutoImageProcessor, AutoModelForObjectDetection

from docling_pp_doc_layout.label_mapping import LABEL_MAP
from docling_pp_doc_layout.options import PPDocLayoutV3Options

if TYPE_CHECKING:
    from collections.abc import Sequence
    from pathlib import Path

    from docling.datamodel.accelerator_options import AcceleratorOptions
    from docling.datamodel.document import ConversionResult
    from docling.datamodel.pipeline_options import BaseLayoutOptions
    from PIL import Image

logger = logging.getLogger(__name__)


class PPDocLayoutV3Model(BaseLayoutModel):
    """Layout engine using PP-DocLayout-V3 via HuggingFace transformers.

    Loads the object-detection checkpoint named by
    ``PPDocLayoutV3Options.model_name`` onto the device chosen from the
    accelerator options, runs batched inference in ``predict_layout``, and
    feeds the detections through docling's ``LayoutPostprocessor`` so the
    output matches the standard pipeline's ``LayoutPrediction`` contract.
    """

    def __init__(
        self,
        artifacts_path: Path | None,
        accelerator_options: AcceleratorOptions,
        options: PPDocLayoutV3Options,
        *,
        enable_remote_services: bool = False,  # noqa: ARG002
    ) -> None:
        """Load the PP-DocLayout-V3 model and image processor.

        Args:
            artifacts_path: Optional artifacts directory, stored for
                interface parity with other docling layout models; the
                checkpoint itself is resolved from ``options.model_name``
                via ``from_pretrained``.
            accelerator_options: Source of the device preference passed to
                ``decide_device``.
            options: PP-DocLayout-V3 options (model name, confidence
                threshold, batch size).
            enable_remote_services: Accepted for signature compatibility;
                unused because inference runs locally.
        """
        self.options = options
        self.artifacts_path = artifacts_path
        self.accelerator_options = accelerator_options

        self._device = decide_device(accelerator_options.device)
        logger.info(
            "Loading PP-DocLayout-V3 model %s on device=%s",
            options.model_name,
            self._device,
        )

        self._image_processor = AutoImageProcessor.from_pretrained(
            options.model_name,
        )
        self._model = AutoModelForObjectDetection.from_pretrained(
            options.model_name,
        ).to(self._device)
        # Inference only: disables dropout/batch-norm training behavior.
        self._model.eval()

        # Checkpoint's class-id -> raw label string mapping; raw labels are
        # translated to DocItemLabel via LABEL_MAP during inference.
        self._id2label: dict[int, str] = self._model.config.id2label
        logger.info("PP-DocLayout-V3 model loaded successfully")

    @classmethod
    def get_options_type(cls) -> type[BaseLayoutOptions]:
        """Return the options class for this layout model."""
        return PPDocLayoutV3Options

    def _run_inference(
        self,
        images: list[Image.Image],
    ) -> list[list[dict]]:
        """Run PP-DocLayout-V3 on a batch of PIL images.

        Args:
            images: Page images to detect layout elements on.

        Returns:
            A list (per image) of lists of detection dicts with keys
            ``label``, ``confidence``, ``l``, ``t``, ``r``, ``b``.
        """
        inputs = self._image_processor(images=images, return_tensors="pt")
        inputs = {k: v.to(self._device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self._model(**inputs)

        # post_process_object_detection expects (height, width); PIL's
        # Image.size is (width, height), hence the reversal.
        target_sizes = [img.size[::-1] for img in images]  # (height, width)
        results = self._image_processor.post_process_object_detection(
            outputs,
            target_sizes=target_sizes,
            threshold=self.options.confidence_threshold,
        )

        batch_detections: list[list[dict]] = []
        for result in results:
            detections: list[dict] = []

            # Some processor versions expose polygon outputs under different
            # keys; fall back to a None per detection so zip(strict=True)
            # below always has matching lengths.
            polys = result.get("polygons") or result.get("polygon_points")
            if polys is None:
                polys = [None] * len(result["scores"])

            for score, label_id, box, poly in zip(
                result["scores"],
                result["labels"],
                result["boxes"],
                polys,
                strict=True,
            ):
                raw_label = self._id2label.get(label_id.item(), "text")
                doc_label = LABEL_MAP.get(raw_label, DocItemLabel.TEXT)

                if poly is not None and len(poly) > 0:
                    # Polygons may be flat [x0, y0, x1, y1, ...] or nested
                    # [[x0, y0], ...]; either way reduce to the axis-aligned
                    # bounding box via min/max.
                    if isinstance(poly[0], int | float):
                        xs = poly[0::2]
                        ys = poly[1::2]
                    else:
                        xs = [pt[0] for pt in poly]
                        ys = [pt[1] for pt in poly]
                    x_min, x_max = min(xs), max(xs)
                    y_min, y_max = min(ys), max(ys)
                else:
                    x_min, y_min, x_max, y_max = box.tolist()

                detections.append({
                    "label": doc_label,
                    "confidence": score.item(),
                    "l": x_min,
                    "t": y_min,
                    "r": x_max,
                    "b": y_max,
                })
            batch_detections.append(detections)

        return batch_detections

    def predict_layout(
        self,
        conv_res: ConversionResult,
        pages: Sequence[Page],
    ) -> Sequence[LayoutPrediction]:
        """Detect layout regions for a batch of document pages.

        Pages with a missing/invalid backend, unknown size, or no
        renderable image are skipped (their existing prediction, if any,
        is passed through). Valid page images are processed in batches of
        ``options.batch_size``; detections are converted to clusters and
        run through docling's layout postprocessor.

        Args:
            conv_res: Conversion result receiving timing and per-page
                confidence scores.
            pages: Pages to process; order is preserved in the output.

        Returns:
            One ``LayoutPrediction`` per input page, in input order.
        """
        pages = list(pages)

        valid_images: list[Image.Image] = []
        # Parallel to `pages`: marks which pages produced an image and thus
        # consume an entry from `batch_detections` below.
        is_page_valid: list[bool] = []

        for page in pages:
            if page._backend is None or not page._backend.is_valid():  # noqa: SLF001
                is_page_valid.append(False)
                continue
            if page.size is None:
                is_page_valid.append(False)
                continue
            page_image = page.get_image(scale=1.0)
            if page_image is None:
                is_page_valid.append(False)
                continue

            valid_images.append(page_image)
            is_page_valid.append(True)

        batch_detections: list[list[dict]] = []
        if valid_images:
            with TimeRecorder(conv_res, "layout"):
                bs = self.options.batch_size
                for i in range(0, len(valid_images), bs):
                    batch = valid_images[i : i + bs]
                    batch_detections.extend(self._run_inference(batch))

        layout_predictions: list[LayoutPrediction] = []
        valid_idx = 0

        for idx, page in enumerate(pages):
            if not is_page_valid[idx]:
                # Skipped page: keep whatever prediction it already had.
                existing = page.predictions.layout or LayoutPrediction()
                layout_predictions.append(existing)
                continue

            detections = batch_detections[valid_idx]
            valid_idx += 1

            clusters: list[Cluster] = []
            for ix, det in enumerate(detections):
                cluster = Cluster(
                    id=ix,
                    label=det["label"],
                    confidence=det["confidence"],
                    bbox=BoundingBox(
                        l=det["l"],
                        t=det["t"],
                        r=det["r"],
                        b=det["b"],
                    ),
                    cells=[],
                )
                clusters.append(cluster)

            processed_clusters, processed_cells = LayoutPostprocessor(page, clusters, self.options).postprocess()

            # np.mean over an empty list (no clusters / no OCR cells) emits
            # RuntimeWarnings and yields NaN; suppress the warnings and let
            # NaN stand as the "no data" score.
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    "Mean of empty slice|invalid value encountered in scalar divide",
                    RuntimeWarning,
                    "numpy",
                )
                conv_res.confidence.pages[page.page_no].layout_score = float(
                    np.mean([c.confidence for c in processed_clusters])
                )
                conv_res.confidence.pages[page.page_no].ocr_score = float(
                    np.mean([c.confidence for c in processed_cells if c.from_ocr])
                )

            prediction = LayoutPrediction(clusters=processed_clusters)
            layout_predictions.append(prediction)

        return layout_predictions