|
|
""" |
|
|
DocLayout-YOLO inference module. |
|
|
|
|
|
Importing this package patches ultralytics to recognize the G2L_CRM custom
layers, which is required to load DocLayout-YOLO model weights.
|
|
|
|
|
Usage: |
|
|
from doclayout_yolo import DocLayoutModel |
|
|
|
|
|
model = DocLayoutModel("model.pt") |
|
|
results = model.predict("document.png") |
|
|
""" |
|
|
|
|
|
import sys |
|
|
|
|
|
|
|
|
from .g2l_crm import G2L_CRM, DilatedBlock, DilatedBottleneck |
|
|
|
|
|
|
|
|
def _patch_ultralytics():
    """
    Patch ultralytics to recognize the G2L_CRM module.

    This must be called before loading any models that use G2L_CRM. It:

    * imports ``ultralytics.nn.modules`` and ``ultralytics.nn.tasks``;
    * returns immediately if ``ultralytics.nn.modules`` already has a
      ``G2L_CRM`` attribute, so repeated calls are no-ops;
    * exposes ``G2L_CRM``, ``DilatedBlock`` and ``DilatedBottleneck`` as
      attributes of ``ultralytics.nn.modules``;
    * registers a synthetic ``ultralytics.nn.modules.g2l_crm`` module in
      ``sys.modules`` carrying the same three classes;
    * binds ``G2L_CRM`` in ``ultralytics.nn.tasks`` and wraps
      ``tasks.parse_model`` via ``_patch_parse_model``.

    Raises:
        ImportError: If ultralytics is not installed. The underlying import
            error is chained as ``__cause__``.
    """
    from types import ModuleType

    try:
        import ultralytics.nn.modules as modules
        import ultralytics.nn.tasks as tasks
    except ImportError as err:
        # Chain explicitly so the traceback shows the real import failure
        # as the cause rather than as an error raised "during handling".
        raise ImportError(
            "ultralytics is required. Install with: pip install ultralytics"
        ) from err

    # An earlier call already applied the patch.
    if hasattr(modules, "G2L_CRM"):
        return

    custom_layers = {
        "G2L_CRM": G2L_CRM,
        "DilatedBlock": DilatedBlock,
        "DilatedBottleneck": DilatedBottleneck,
    }

    # Make the layers reachable as ultralytics.nn.modules.<name>.
    for layer_name, layer_cls in custom_layers.items():
        setattr(modules, layer_name, layer_cls)

    # Stand-in for a real ``ultralytics.nn.modules.g2l_crm`` submodule.
    # NOTE(review): presumably this lets checkpoints that were pickled with
    # classes under that import path be resolved; confirm against the weights.
    fake_module = ModuleType("ultralytics.nn.modules.g2l_crm")
    for layer_name, layer_cls in custom_layers.items():
        setattr(fake_module, layer_name, layer_cls)
    sys.modules["ultralytics.nn.modules.g2l_crm"] = fake_module

    # Bind the name in the tasks namespace as well, then wrap parse_model.
    tasks.G2L_CRM = G2L_CRM
    _patch_parse_model(tasks)
|
|
|
|
|
|
|
|
def _patch_parse_model(tasks):
    """
    Wrap ``tasks.parse_model`` so the global name ``G2L_CRM`` resolves inside it.

    For each call, the wrapper binds ``G2L_CRM`` in the global namespace of
    the original ``parse_model`` function, i.e. the namespace of the module
    that defined it. After the call it restores whatever binding was there
    before, or removes the name if there was none. All positional and keyword
    arguments are forwarded unchanged.

    NOTE(review): this assumes ultralytics' ``parse_model`` resolves layer
    names from the model YAML through its module globals. Confirm this
    against the installed ultralytics version.

    Args:
        tasks: The ``ultralytics.nn.tasks`` module, or any object with a
            plain-function ``parse_model`` attribute. That attribute is
            replaced in place with the wrapper.
    """
    import functools

    original_parse_model = tasks.parse_model
    # The dict that bare global names inside parse_model are looked up in.
    parse_globals = original_parse_model.__globals__
    # Sentinel distinguishing "name was unbound" from any real prior value.
    unbound = object()

    @functools.wraps(original_parse_model)
    def patched_parse_model(*args, **kwargs):
        previous = parse_globals.get("G2L_CRM", unbound)
        parse_globals["G2L_CRM"] = G2L_CRM
        try:
            return original_parse_model(*args, **kwargs)
        finally:
            # Leave the namespace exactly as it was before this call.
            if previous is unbound:
                parse_globals.pop("G2L_CRM", None)
            else:
                parse_globals["G2L_CRM"] = previous

    tasks.parse_model = patched_parse_model
|
|
|
|
|
|
|
|
|
|
|
# Register the G2L_CRM layers with ultralytics as a side effect of importing
# this package. Per _patch_ultralytics' own contract, this must happen before
# any model that uses G2L_CRM is loaded.
# NOTE(review): this call also runs ahead of the `.model` import. That order
# presumably matters if `.model` touches ultralytics at import time; confirm
# before reordering.
_patch_ultralytics()
|
|
|
|
|
|
|
|
from .model import DocLayoutModel |
|
|
|
|
|
# Version string of this package.
__version__ = "0.1.0"

# Names exported by a star-import of this package.
__all__ = [
    "DocLayoutModel",
    "G2L_CRM",
]
|
|
|