Upload folder using huggingface_hub
Browse files- README.md +98 -0
- config.json +22 -0
- doclayout_yolo/__init__.py +93 -0
- doclayout_yolo/__pycache__/__init__.cpython-312.pyc +0 -0
- doclayout_yolo/__pycache__/__init__.cpython-313.pyc +0 -0
- doclayout_yolo/__pycache__/g2l_crm.cpython-312.pyc +0 -0
- doclayout_yolo/__pycache__/model.cpython-312.pyc +0 -0
- doclayout_yolo/g2l_crm.py +118 -0
- doclayout_yolo/model.py +154 -0
- model.onnx +3 -0
- model.pt +3 -0
README.md
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- document-layout
|
| 5 |
+
- object-detection
|
| 6 |
+
- yolo
|
| 7 |
+
- document-analysis
|
| 8 |
+
library_name: ultralytics
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# DocLayout-YOLO - Docstructbench
|
| 12 |
+
|
| 13 |
+
Document layout detection model based on [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO).
|
| 14 |
+
|
| 15 |
+
## Model Description
|
| 16 |
+
|
| 17 |
+
- **Architecture**: YOLOv10m with G2L_CRM (Global-to-Local Context Refining Module)
|
| 18 |
+
- **Classes**: 10 document layout elements
|
| 19 |
+
- **Input Size**: 1024x1024
|
| 20 |
+
- **Paper**: [DocLayout-YOLO](https://arxiv.org/abs/2410.12628)
|
| 21 |
+
|
| 22 |
+
### Classes
|
| 23 |
+
|
| 24 |
+
- `title`
|
| 25 |
+
- `plain_text`
|
| 26 |
+
- `abandon`
|
| 27 |
+
- `figure`
|
| 28 |
+
- `figure_caption`
|
| 29 |
+
- `table`
|
| 30 |
+
- `table_caption`
|
| 31 |
+
- `table_footnote`
|
| 32 |
+
- `isolate_formula`
|
| 33 |
+
- `formula_caption`
|
| 34 |
+
|
| 35 |
+
## Usage
|
| 36 |
+
|
| 37 |
+
### PyTorch
|
| 38 |
+
|
| 39 |
+
```python
|
| 40 |
+
from huggingface_hub import snapshot_download
|
| 41 |
+
import sys
|
| 42 |
+
|
| 43 |
+
# Download model (includes code + weights)
|
| 44 |
+
repo_path = snapshot_download("anyformat-ai/doclayout-yolo-docstructbench")
|
| 45 |
+
|
| 46 |
+
# Import and use
|
| 47 |
+
sys.path.insert(0, repo_path)
|
| 48 |
+
from doclayout_yolo import DocLayoutModel
|
| 49 |
+
|
| 50 |
+
model = DocLayoutModel(f"{repo_path}/model.pt")
|
| 51 |
+
results = model.predict("document.png")
|
| 52 |
+
|
| 53 |
+
for det in results:
|
| 54 |
+
print(f"{det['class_name']}: {det['confidence']:.2f} at {det['bbox']}")
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
### ONNX
|
| 58 |
+
|
| 59 |
+
```python
|
| 60 |
+
import onnxruntime as ort
|
| 61 |
+
import numpy as np
|
| 62 |
+
from huggingface_hub import hf_hub_download
|
| 63 |
+
import json
|
| 64 |
+
|
| 65 |
+
# Download ONNX model and config
|
| 66 |
+
model_path = hf_hub_download("anyformat-ai/doclayout-yolo-docstructbench", "model.onnx")
|
| 67 |
+
config_path = hf_hub_download("anyformat-ai/doclayout-yolo-docstructbench", "config.json")
|
| 68 |
+
|
| 69 |
+
with open(config_path) as f:
|
| 70 |
+
config = json.load(f)
|
| 71 |
+
|
| 72 |
+
session = ort.InferenceSession(model_path)
|
| 73 |
+
# Preprocess image to (1, 3, 1024, 1024) float32, normalized to [0, 1]
|
| 74 |
+
# Run inference and post-process outputs
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
## Requirements
|
| 78 |
+
|
| 79 |
+
```
|
| 80 |
+
ultralytics
|
| 81 |
+
huggingface-hub
|
| 82 |
+
onnxruntime # for ONNX inference
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## Citation
|
| 86 |
+
|
| 87 |
+
```bibtex
|
| 88 |
+
@article{zhao2024doclayout,
|
| 89 |
+
title={DocLayout-YOLO: Enhancing Document Layout Analysis through Diverse Synthetic Data and Global-to-Local Adaptive Perception},
|
| 90 |
+
author={Zhao, Zhiyuan and Kang, Hengrui and Wang, Bin and He, Conghui},
|
| 91 |
+
journal={arXiv preprint arXiv:2410.12628},
|
| 92 |
+
year={2024}
|
| 93 |
+
}
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
## License
|
| 97 |
+
|
| 98 |
+
Apache 2.0
|
config.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "doclayout-yolo",
|
| 3 |
+
"model_name": "docstructbench",
|
| 4 |
+
"architecture": "yolov10m-g2l-crm",
|
| 5 |
+
"num_classes": 10,
|
| 6 |
+
"class_names": [
|
| 7 |
+
"title",
|
| 8 |
+
"plain_text",
|
| 9 |
+
"abandon",
|
| 10 |
+
"figure",
|
| 11 |
+
"figure_caption",
|
| 12 |
+
"table",
|
| 13 |
+
"table_caption",
|
| 14 |
+
"table_footnote",
|
| 15 |
+
"isolate_formula",
|
| 16 |
+
"formula_caption"
|
| 17 |
+
],
|
| 18 |
+
"input_size": 1024,
|
| 19 |
+
"description": "Document layout detection for financial documents",
|
| 20 |
+
"source": "https://github.com/opendatalab/DocLayout-YOLO",
|
| 21 |
+
"paper": "https://arxiv.org/abs/2410.12628"
|
| 22 |
+
}
|
doclayout_yolo/__init__.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DocLayout-YOLO inference module.
|
| 3 |
+
|
| 4 |
+
This module patches ultralytics to support G2L_CRM custom layers,
|
| 5 |
+
enabling loading of DocLayout-YOLO model weights.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
from doclayout_yolo import DocLayoutModel
|
| 9 |
+
|
| 10 |
+
model = DocLayoutModel("model.pt")
|
| 11 |
+
results = model.predict("document.png")
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import sys
|
| 15 |
+
|
| 16 |
+
# Import our custom modules first
|
| 17 |
+
from .g2l_crm import G2L_CRM, DilatedBlock, DilatedBottleneck
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _patch_ultralytics():
    """
    Patch ultralytics to recognize the G2L_CRM custom module.

    Must be called before loading any model that uses G2L_CRM. It:

    1. injects ``G2L_CRM``/``DilatedBlock``/``DilatedBottleneck`` into
       ``ultralytics.nn.modules`` and ``ultralytics.nn.tasks``,
    2. registers a fake ``ultralytics.nn.modules.g2l_crm`` module so PyTorch
       can unpickle weight files that reference that import path,
    3. wraps ``tasks.parse_model`` so YAML configs naming G2L_CRM resolve.

    Raises
    ------
    ImportError
        If ultralytics is not installed.
    """
    from types import ModuleType

    try:
        import ultralytics.nn.modules as modules
        import ultralytics.nn.tasks as tasks
    except ImportError as e:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "ultralytics is required. Install with: pip install ultralytics"
        ) from e

    # Already patched, or a fork that ships G2L_CRM natively -- nothing to do.
    if hasattr(modules, "G2L_CRM"):
        return

    # Inject the custom layers into ultralytics.nn.modules.
    modules.G2L_CRM = G2L_CRM
    modules.DilatedBlock = DilatedBlock
    modules.DilatedBottleneck = DilatedBottleneck

    # Create a fake ultralytics.nn.modules.g2l_crm module for PyTorch
    # unpickling -- the weights file references this path, so it must exist.
    fake_module = ModuleType("ultralytics.nn.modules.g2l_crm")
    fake_module.G2L_CRM = G2L_CRM
    fake_module.DilatedBlock = DilatedBlock
    fake_module.DilatedBottleneck = DilatedBottleneck
    sys.modules["ultralytics.nn.modules.g2l_crm"] = fake_module

    # parse_model resolves layer names through its globals(), so inject there too.
    tasks.G2L_CRM = G2L_CRM

    # Monkey-patch parse_model so G2L_CRM resolves during model construction.
    _patch_parse_model(tasks)


def _patch_parse_model(tasks):
    """Wrap ``tasks.parse_model`` so G2L_CRM is resolvable in its globals."""
    import functools

    original_parse_model = tasks.parse_model
    # Record once whether the name pre-existed, instead of copying the whole
    # globals dict on every call like a naive snapshot/restore would.
    had_g2l_crm = "G2L_CRM" in original_parse_model.__globals__

    @functools.wraps(original_parse_model)
    def patched_parse_model(d, ch, verbose=True):
        # Temporarily inject G2L_CRM into the function's globals. This is a
        # hack, but necessary because parse_model looks modules up via
        # globals()[m].
        original_parse_model.__globals__["G2L_CRM"] = G2L_CRM
        try:
            return original_parse_model(d, ch, verbose)
        finally:
            # Restore the original state only if we introduced the name.
            if not had_g2l_crm:
                original_parse_model.__globals__.pop("G2L_CRM", None)

    tasks.parse_model = patched_parse_model


# Apply patch on import
_patch_ultralytics()
|
| 88 |
+
|
| 89 |
+
# Public API
|
| 90 |
+
from .model import DocLayoutModel
|
| 91 |
+
|
| 92 |
+
__all__ = ["DocLayoutModel", "G2L_CRM"]
|
| 93 |
+
__version__ = "0.1.0"
|
doclayout_yolo/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (2.94 kB). View file
|
|
|
doclayout_yolo/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (1.21 kB). View file
|
|
|
doclayout_yolo/__pycache__/g2l_crm.cpython-312.pyc
ADDED
|
Binary file (8.73 kB). View file
|
|
|
doclayout_yolo/__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (5.65 kB). View file
|
|
|
doclayout_yolo/g2l_crm.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
G2L_CRM (Global-to-Local Context Refining Module) for document layout analysis.
|
| 3 |
+
|
| 4 |
+
Based on DocLayout-YOLO: https://github.com/opendatalab/DocLayout-YOLO
|
| 5 |
+
Paper: https://arxiv.org/abs/2410.12628
|
| 6 |
+
Original Authors: Zhiyuan Zhao, Hengrui Kang, Bin Wang, Conghui He
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
from ultralytics.nn.modules.conv import Conv
|
| 14 |
+
from ultralytics.nn.modules.block import CIB
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DilatedBlock(nn.Module):
    """
    Dilated convolution block with multi-scale fusion.

    Reapplies the weights of one shared conv (``self.dcv``) at several
    dilation rates in parallel, projects each branch with a shared 1x1 conv,
    then fuses the branches either by elementwise sum ("sum") or with a
    gated linear unit ("glu").

    Parameters
    ----------
    c : int
        Number of input/output channels.
    dilation : sequence of int
        Dilation rates applied in parallel.
    k : int
        Kernel size of the shared dilated convolution.
    fuse : str, default="sum"
        Fusion mode: "sum" or "glu".
    shortcut : bool, default=True
        Add the input back onto the fused output (residual connection).
    """

    def __init__(self, c, dilation, k, fuse="sum", shortcut=True):
        super().__init__()
        self.dilation = dilation
        self.k = k
        self.cv2 = Conv(c, c, k=1, s=1)
        self.add = shortcut

        self.fuse = fuse
        if fuse == "glu":
            # Depthwise gating over the concatenated branches, then a grouped
            # 1x1 projection back down to c channels.
            self.conv_gating = Conv(
                c * len(self.dilation), c * len(self.dilation), k=1, s=1, g=c * len(self.dilation)
            )
            self.conv1x1 = Conv(c * len(self.dilation), c, k=1, s=1, g=c)
        elif fuse == "sum":
            self.conv1x1 = Conv(c, c, k=1, s=1, g=c)

        # Shared conv whose weights are reused at every dilation rate.
        self.dcv = Conv(c, c, k=self.k, s=1)

    def dilated_conv(self, x, dilation):
        """Apply self.dcv's weights at the given dilation rate.

        Handles both the unfused layout (conv + BN + act) and the fused
        layout (BN folded into the conv, whose bias then carries it).
        """
        act = self.dcv.act
        weight = self.dcv.conv.weight
        # "same" padding for a dilated kernel of size k.
        padding = dilation * (self.k // 2)

        bn = getattr(self.dcv, "bn", None)
        if bn is not None:
            out = F.conv2d(x, weight, stride=1, padding=padding, dilation=dilation)
            return act(bn(out))
        # Fused case: nn.Conv2d always has a .bias attribute (possibly None),
        # so a plain getattr suffices -- no hasattr dance needed.
        bias = getattr(self.dcv.conv, "bias", None)
        return act(F.conv2d(x, weight, bias=bias, stride=1, padding=padding, dilation=dilation))

    def forward(self, x):
        """Run every dilation branch, fuse, and optionally add x back."""
        dx = [self.cv2(self.dilated_conv(x, d)) for d in self.dilation]
        if self.fuse == "glu":
            dx = torch.cat(dx, dim=1)
            gate = torch.sigmoid(self.conv_gating(dx))
            dx = self.conv1x1(dx * gate)
        elif self.fuse == "sum":
            # Elementwise sum of the branches, then the grouped projection.
            dx = self.conv1x1(torch.stack(dx, dim=0).sum(dim=0))

        return x + dx if self.add else dx
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class DilatedBottleneck(nn.Module):
    """
    Bottleneck with a dilated multi-scale block between its two convs.

    Parameters
    ----------
    c1, c2 : int
        Input / output channels.
    shortcut : bool, default=True
        Residual add (only applied when c1 == c2).
    dilation : sequence of int, default=(1, 2, 3)
        Dilation rates forwarded to DilatedBlock. (Default changed from a
        mutable list to a tuple; it is only iterated and len()'d downstream,
        so behavior is unchanged.)
    block_k : int, default=3
        Kernel size of the dilated block.
    fuse : str, default="sum"
        Fusion mode of the dilated block ("sum" or "glu").
    g : int, default=1
        Groups for the second conv.
    k : tuple, default=(3, 3)
        Kernel sizes for the first and second convs.
    e : float, default=0.5
        Hidden-channel expansion ratio.
    """

    def __init__(self, c1, c2, shortcut=True, dilation=(1, 2, 3), block_k=3, fuse="sum", g=1, k=(3, 3), e=0.5):
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, k[0], 1)
        self.cv2 = Conv(c_, c2, k[1], 1, g=g)
        self.dilated_block = DilatedBlock(c_, dilation, block_k, fuse)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """cv1 -> dilated block -> cv2, with optional residual add."""
        out = self.cv2(self.dilated_block(self.cv1(x)))
        return x + out if self.add else out
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class G2L_CRM(nn.Module):
    """
    Global-to-Local Context Refining Module.

    CSP-style bottleneck: cv1 splits the input into two halves, one half is
    refined by a chain of ``n`` bottlenecks (DilatedBottleneck or CIB), and
    the two halves plus every intermediate output are concatenated and
    projected to ``c2`` channels by cv2. With ``use_dilated=True`` the
    bottlenecks use multi-scale dilated convolutions for wider context.

    Parameters
    ----------
    c1, c2 : int
        Input / output channels.
    n : int, default=1
        Number of bottleneck blocks in the chain.
    shortcut : bool, default=False
        Residual add inside each bottleneck.
    use_dilated : bool, default=False
        Use DilatedBottleneck instead of CIB.
    dilation : sequence of int, default=(1, 2, 3)
        Dilation rates (dilated mode only). Default changed from a mutable
        list to a tuple; it is only read downstream, so behavior is unchanged.
    block_k : int, default=3
        Dilated-block kernel size.
    fuse : str, default="sum"
        Dilated-block fusion mode ("sum" or "glu").
    g : int, default=1
        Conv groups inside the dilated bottleneck.
    e : float, default=0.5
        Hidden-channel expansion ratio.
    """

    def __init__(
        self, c1, c2, n=1, shortcut=False, use_dilated=False, dilation=(1, 2, 3), block_k=3, fuse="sum", g=1, e=0.5
    ):
        super().__init__()
        self.c = int(c2 * e)  # hidden channels per branch
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        # The 2 halves from cv1 plus the n bottleneck outputs get concatenated.
        self.cv2 = Conv((2 + n) * self.c, c2, 1)

        if use_dilated:
            self.m = nn.ModuleList(
                DilatedBottleneck(self.c, self.c, shortcut, dilation, block_k, fuse, g, k=((3, 3), (3, 3)), e=1.0)
                for _ in range(n)
            )
        else:
            self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0) for _ in range(n))

    def forward(self, x):
        """Split via chunk, refine one half through the chain, concat, project."""
        y = list(self.cv1(x).chunk(2, 1))
        for m in self.m:
            y.append(m(y[-1]))
        return self.cv2(torch.cat(y, 1))

    def forward_split(self, x):
        """Same as forward but splits with split() instead of chunk().

        NOTE(review): presumably provided as an export-friendly variant, as
        in other ultralytics CSP blocks -- confirm before relying on it.
        """
        y = list(self.cv1(x).split((self.c, self.c), 1))
        for m in self.m:
            y.append(m(y[-1]))
        return self.cv2(torch.cat(y, 1))
|
doclayout_yolo/model.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Simple DocLayout model for inference."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Dict, List, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from ultralytics import YOLO
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DocLayoutModel:
    """
    Document layout detection model.

    Wraps an ultralytics YOLO checkpoint, loading the weights lazily on first
    use and mapping predicted class ids to human-readable layout labels.

    Examples
    --------
    >>> model = DocLayoutModel("model.pt")
    >>> results = model.predict("document.png")
    >>> for det in results:
    ...     print(f"{det['class_name']}: {det['confidence']:.2f}")
    """

    # Default class mappings
    DOCSTRUCTBENCH_CLASSES = {
        0: "title",
        1: "plain_text",
        2: "abandon",
        3: "figure",
        4: "figure_caption",
        5: "table",
        6: "table_caption",
        7: "table_footnote",
        8: "isolate_formula",
        9: "formula_caption",
    }

    DOCLAYNET_CLASSES = {
        0: "Caption",
        1: "Footnote",
        2: "Formula",
        3: "List-item",
        4: "Page-footer",
        5: "Page-header",
        6: "Picture",
        7: "Section-header",
        8: "Table",
        9: "Text",
        10: "Title",
    }

    def __init__(
        self,
        weights_path: Union[str, Path],
        config_path: Union[str, Path, None] = None,
        model_type: str = "auto",
    ):
        """
        Initialize model.

        Parameters
        ----------
        weights_path : str or Path
            Path to model weights (.pt file)
        config_path : str or Path, optional
            Path to config.json with class names. If None, auto-detects from weights filename.
        model_type : str, default="auto"
            Model type: "docstructbench", "doclaynet", or "auto" (detect from filename)
        """
        self.weights_path = Path(weights_path)
        self._model = None  # populated lazily by the `model` property

        if config_path:
            # Explicit config wins: class ids follow list order in the file.
            with open(config_path) as f:
                cfg = json.load(f)
            self.class_names = dict(enumerate(cfg["class_names"]))
        else:
            self.class_names = self._get_class_names(model_type)

    def _get_class_names(self, model_type: str) -> Dict[int, str]:
        """Resolve the id -> label mapping for the given model type."""
        if model_type == "auto":
            # Sniff the checkpoint filename; docstructbench is the default.
            stem = self.weights_path.stem.lower()
            return self.DOCLAYNET_CLASSES if "doclaynet" in stem else self.DOCSTRUCTBENCH_CLASSES
        if model_type == "doclaynet":
            return self.DOCLAYNET_CLASSES
        if model_type == "docstructbench":
            return self.DOCSTRUCTBENCH_CLASSES
        raise ValueError(f"Unknown model type: {model_type}")

    @property
    def model(self) -> YOLO:
        """Lazy-load the YOLO model."""
        if self._model is None:
            self._model = YOLO(str(self.weights_path))
        return self._model

    def predict(
        self,
        source: Union[str, Path, Image.Image, np.ndarray],
        confidence: float = 0.2,
        image_size: int = 1024,
        device: str = "cpu",
    ) -> List[Dict]:
        """
        Run inference on an image.

        Parameters
        ----------
        source : str, Path, PIL.Image, or np.ndarray
            Input image
        confidence : float, default=0.2
            Confidence threshold
        image_size : int, default=1024
            Input image size
        device : str, default="cpu"
            Device to run on ("cpu", "cuda", "mps")

        Returns
        -------
        List[Dict]
            List of detections, each with keys:
            - class_id: int
            - class_name: str
            - confidence: float
            - bbox: [x1, y1, x2, y2]
        """
        raw_results = self.model.predict(
            source=str(source) if isinstance(source, Path) else source,
            imgsz=image_size,
            conf=confidence,
            device=device,
            save=False,
            verbose=False,
        )

        detections: List[Dict] = []
        for result in raw_results:
            for box in result.boxes:
                class_id = int(box.cls[0])
                detections.append(
                    {
                        "class_id": class_id,
                        "class_name": self.class_names.get(class_id, f"class_{class_id}"),
                        "confidence": float(box.conf[0]),
                        "bbox": box.xyxy[0].tolist(),
                    }
                )
        return detections
|
model.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0142c1154f5f4fcb5eb14d5f29d9cebfaee96433b6d2a99c36bb07779cd7a388
|
| 3 |
+
size 75823701
|
model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f1457fe54bb1dedc4b1d1b7b07348288ab63c730569343f3e7a8194e69d39266
|
| 3 |
+
size 40597687
|