andrew33333's picture
Upload folder using huggingface_hub
c147abc verified
"""
DocLayout-YOLO inference module.
This module patches ultralytics to support G2L_CRM custom layers,
enabling loading of DocLayout-YOLO model weights.
Usage:
    from doclayout_yolo import DocLayoutModel

    model = DocLayoutModel("model.pt")
    results = model.predict("document.png")
"""
import sys
# Import our custom modules first
from .g2l_crm import G2L_CRM, DilatedBlock, DilatedBottleneck
def _patch_ultralytics():
    """
    Patch ultralytics so it recognizes the G2L_CRM custom layers.

    This must be called before loading any models that use G2L_CRM. It is a
    no-op when ``ultralytics.nn.modules`` already exposes ``G2L_CRM`` (for
    example when a fork that ships the layer is installed).

    Raises:
        ImportError: If ultralytics cannot be imported. The underlying import
            failure is attached as the explicit cause.
    """
    from types import ModuleType

    try:
        import ultralytics.nn.modules as modules
        import ultralytics.nn.tasks as tasks
    except ImportError as err:
        # Chain the original failure so the real reason (package missing vs.
        # a broken transitive import) stays visible in the traceback.
        raise ImportError(
            "ultralytics is required. Install with: pip install ultralytics"
        ) from err

    # Check if already patched (fork is installed)
    if hasattr(modules, "G2L_CRM"):
        return  # Already has G2L_CRM, no patching needed

    custom_layers = {
        "G2L_CRM": G2L_CRM,
        "DilatedBlock": DilatedBlock,
        "DilatedBottleneck": DilatedBottleneck,
    }

    # Inject the custom layers into ultralytics.nn.modules
    modules.__dict__.update(custom_layers)

    # Create a stand-in ultralytics.nn.modules.g2l_crm module for PyTorch
    # unpickling: the weights file references this path, so it must exist.
    fake_module = ModuleType("ultralytics.nn.modules.g2l_crm")
    fake_module.__dict__.update(custom_layers)
    sys.modules["ultralytics.nn.modules.g2l_crm"] = fake_module

    # Also inject into the tasks module namespace for its globals() lookup
    tasks.__dict__["G2L_CRM"] = G2L_CRM

    # Wrap parse_model so G2L_CRM is present in its globals while it runs
    _patch_parse_model(tasks)
def _patch_parse_model(tasks):
    """Wrap ``tasks.parse_model`` so ``G2L_CRM`` is in its globals while it runs.

    The wrapper installs ``G2L_CRM`` into the global namespace of the original
    ``parse_model`` for the duration of each call, then puts that namespace
    entry back exactly as it was: removed if it was absent, or reset to its
    previous value if it was already defined.

    Args:
        tasks: Object (normally the ``ultralytics.nn.tasks`` module) whose
            ``parse_model`` attribute is replaced by the wrapper.
    """
    import functools

    original_parse_model = tasks.parse_model
    # Sentinel meaning "the name was not defined before the call".
    absent = object()

    @functools.wraps(original_parse_model)
    def patched_parse_model(d, ch, verbose=True):
        # HACK: parse_model resolves module names through globals()[m], so the
        # class has to be reachable from the function's own globals.
        fn_globals = original_parse_model.__globals__
        previous = fn_globals.get("G2L_CRM", absent)
        fn_globals["G2L_CRM"] = G2L_CRM
        try:
            return original_parse_model(d, ch, verbose)
        finally:
            # Restore only the one entry that was touched.
            if previous is absent:
                fn_globals.pop("G2L_CRM", None)
            else:
                fn_globals["G2L_CRM"] = previous

    tasks.parse_model = patched_parse_model
# Apply the patch at import time, so importing this package is enough to
# register the custom layers with ultralytics.
_patch_ultralytics()
# Public API
# NOTE(review): this import sits after the patch call; presumably .model
# relies on the patched ultralytics - confirm before reordering.
from .model import DocLayoutModel
# Names exported by a star-import of this package.
__all__ = ["DocLayoutModel", "G2L_CRM"]
# Package version string.
__version__ = "0.1.0"