|
|
"""Simple DocLayout model for inference.""" |
|
|
|
|
|
import json |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Union |
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from ultralytics import YOLO |
|
|
|
|
|
|
|
|
class DocLayoutModel:
    """
    Document layout detection model backed by an Ultralytics YOLO checkpoint.

    Raw class ids produced by the detector are mapped to human-readable
    labels, either from a built-in label table (DocStructBench or DocLayNet)
    or from a ``config.json`` file that provides a ``class_names`` list.

    Examples
    --------
    >>> model = DocLayoutModel("model.pt")
    >>> results = model.predict("document.png")
    >>> for det in results:
    ...     print(f"{det['class_name']}: {det['confidence']:.2f}")
    """

    # Label table for DocStructBench-style checkpoints (the fallback for "auto").
    DOCSTRUCTBENCH_CLASSES = {
        0: "title",
        1: "plain_text",
        2: "abandon",
        3: "figure",
        4: "figure_caption",
        5: "table",
        6: "table_caption",
        7: "table_footnote",
        8: "isolate_formula",
        9: "formula_caption",
    }

    # Label table for DocLayNet checkpoints.
    DOCLAYNET_CLASSES = {
        0: "Caption",
        1: "Footnote",
        2: "Formula",
        3: "List-item",
        4: "Page-footer",
        5: "Page-header",
        6: "Picture",
        7: "Section-header",
        8: "Table",
        9: "Text",
        10: "Title",
    }

    def __init__(
        self,
        weights_path: Union[str, Path],
        config_path: Union[str, Path, None] = None,
        model_type: str = "auto",
    ):
        """
        Initialize model.

        The YOLO weights are not loaded here; they are loaded lazily on the
        first access of :attr:`model`.

        Parameters
        ----------
        weights_path : str or Path
            Path to model weights (.pt file)
        config_path : str or Path, optional
            Path to a UTF-8 ``config.json`` whose ``class_names`` list gives
            the label for each class id (list index = class id). When given,
            ``model_type`` is ignored. If None, labels come from ``model_type``.
        model_type : str, default="auto"
            Model type: "docstructbench", "doclaynet", or "auto" (detect from
            the weights filename). Matched case-insensitively.

        Raises
        ------
        ValueError
            If ``config_path`` is not given and ``model_type`` is unknown.
        KeyError
            If the config file has no ``class_names`` entry.
        """
        self.weights_path = Path(weights_path)
        # Populated on first access of ``self.model``.
        self._model = None

        self.class_names: Dict[int, str]
        if config_path:
            # JSON text is UTF-8; don't depend on the platform default encoding.
            with open(config_path, encoding="utf-8") as f:
                config = json.load(f)
            self.class_names = dict(enumerate(config["class_names"]))
        else:
            self.class_names = self._get_class_names(model_type)

    def _get_class_names(self, model_type: str) -> Dict[int, str]:
        """
        Return a fresh id -> label mapping for ``model_type``.

        "auto" selects the DocLayNet table when the weights filename stem
        contains "doclaynet" (case-insensitive), and the DocStructBench table
        otherwise.

        Raises
        ------
        ValueError
            If ``model_type`` is not one of "auto", "doclaynet",
            "docstructbench".
        """
        kind = str(model_type).lower()
        if kind == "auto":
            if "doclaynet" in self.weights_path.stem.lower():
                kind = "doclaynet"
            else:
                kind = "docstructbench"

        if kind == "doclaynet":
            table = self.DOCLAYNET_CLASSES
        elif kind == "docstructbench":
            table = self.DOCSTRUCTBENCH_CLASSES
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        # Return a copy: handing out the class-level dict would let one
        # instance's edits to ``class_names`` leak into every other instance.
        return dict(table)

    @property
    def model(self) -> "YOLO":
        """Return the underlying YOLO model, loading the weights on first access."""
        if self._model is None:
            self._model = YOLO(str(self.weights_path))
        return self._model

    def predict(
        self,
        source: Union[str, Path, "Image.Image", np.ndarray],
        confidence: float = 0.2,
        image_size: int = 1024,
        device: str = "cpu",
    ) -> List[Dict]:
        """
        Run inference on an image.

        Parameters
        ----------
        source : str, Path, PIL.Image, or np.ndarray
            Input image. A ``Path`` is converted to ``str`` before being
            passed to YOLO.
        confidence : float, default=0.2
            Confidence threshold
        image_size : int, default=1024
            Input image size
        device : str, default="cpu"
            Device to run on ("cpu", "cuda", "mps")

        Returns
        -------
        List[Dict]
            List of detections across all returned results, each with keys:
            - class_id: int
            - class_name: str (``"class_<id>"`` if the id has no label)
            - confidence: float
            - bbox: [x1, y1, x2, y2] as reported by ``Boxes.xyxy``
        """
        results = self.model.predict(
            source=str(source) if isinstance(source, Path) else source,
            imgsz=image_size,
            conf=confidence,
            device=device,
            save=False,
            verbose=False,
        )

        detections = []
        for result in results:
            boxes = result.boxes
            # Results from non-detection checkpoints may carry ``boxes=None``.
            if boxes is None:
                continue
            # Convert each whole tensor once per image instead of indexing and
            # converting scalars box by box (three conversions per detection).
            for cls, conf, xyxy in zip(
                boxes.cls.tolist(), boxes.conf.tolist(), boxes.xyxy.tolist()
            ):
                class_id = int(cls)
                detections.append(
                    {
                        "class_id": class_id,
                        "class_name": self.class_names.get(class_id, f"class_{class_id}"),
                        "confidence": float(conf),
                        "bbox": xyxy,
                    }
                )

        return detections
|
|
|