"""
DocLayout-YOLO inference module.

This module patches ultralytics to support the G2L_CRM custom layers,
enabling DocLayout-YOLO model weights to be loaded.

Usage:
    from doclayout_yolo import DocLayoutModel

    model = DocLayoutModel("model.pt")
    results = model.predict("document.png")
"""

import sys

# Import our custom modules first
from .g2l_crm import G2L_CRM, DilatedBlock, DilatedBottleneck

# Custom layer classes exposed to ultralytics, keyed by the attribute name
# under which they are registered.
_CUSTOM_LAYERS = {
    "G2L_CRM": G2L_CRM,
    "DilatedBlock": DilatedBlock,
    "DilatedBottleneck": DilatedBottleneck,
}

# Import path the weights file references for the custom layers.
_FAKE_MODULE_NAME = "ultralytics.nn.modules.g2l_crm"


def _patch_ultralytics():
    """
    Patch ultralytics so it recognizes the G2L_CRM module.

    This must be called before loading any models that use G2L_CRM. It does
    nothing when ``ultralytics.nn.modules`` already exposes ``G2L_CRM``
    (e.g. a fork that ships the layer is installed, or the patch already ran).

    Raises:
        ImportError: if ultralytics is not installed.
    """
    from types import ModuleType

    try:
        import ultralytics.nn.modules as modules
        import ultralytics.nn.tasks as tasks
    except ImportError as err:
        raise ImportError(
            "ultralytics is required. Install with: pip install ultralytics"
        ) from err

    # Already patched, or the fork providing G2L_CRM is installed.
    if hasattr(modules, "G2L_CRM"):
        return

    # Expose the custom layers as attributes of ultralytics.nn.modules.
    for layer_name, layer_cls in _CUSTOM_LAYERS.items():
        setattr(modules, layer_name, layer_cls)

    # The weights file references classes by the path
    # ``ultralytics.nn.modules.g2l_crm``; register a stand-in module under
    # that name so PyTorch unpickling can resolve them.
    fake_module = ModuleType(_FAKE_MODULE_NAME)
    for layer_name, layer_cls in _CUSTOM_LAYERS.items():
        setattr(fake_module, layer_name, layer_cls)
    sys.modules[_FAKE_MODULE_NAME] = fake_module
    # Bind it on the parent package as a real submodule import would.
    setattr(modules, _FAKE_MODULE_NAME.rsplit(".", 1)[1], fake_module)

    # parse_model resolves layer names via the tasks module's globals().
    tasks.__dict__["G2L_CRM"] = G2L_CRM

    # Wrap parse_model so G2L_CRM is guaranteed to be in its globals.
    _patch_parse_model(tasks)


def _patch_parse_model(tasks):
    """
    Wrap ``tasks.parse_model`` so ``G2L_CRM`` is resolvable while it runs.

    parse_model looks layer names up with ``globals()[m]``, i.e. in the
    globals of the module that defines it. For the duration of each call the
    wrapper sets ``G2L_CRM`` in those globals, and afterwards removes it only
    if it was not present before the call.

    NOTE(review): this only makes the name resolvable. It does not add
    G2L_CRM to any module sets defined inside parse_model itself (such as
    ones that control channel-argument handling) — confirm that models built
    from YAML receive the expected constructor arguments.

    Args:
        tasks: the ``ultralytics.nn.tasks`` module; its ``parse_model``
            attribute is replaced with the wrapper.
    """
    import functools

    original_parse_model = tasks.parse_model
    # A function's __globals__ is the namespace globals() returns inside it.
    parse_globals = original_parse_model.__globals__

    @functools.wraps(original_parse_model)
    def patched_parse_model(*args, **kwargs):
        # Only a membership flag is needed to undo the injection; there is no
        # need to snapshot the whole namespace on every call.
        added = "G2L_CRM" not in parse_globals
        parse_globals["G2L_CRM"] = G2L_CRM
        try:
            return original_parse_model(*args, **kwargs)
        finally:
            if added:
                parse_globals.pop("G2L_CRM", None)

    tasks.parse_model = patched_parse_model


# Apply patch on import
_patch_ultralytics()

# Public API (imported after the patch has been applied)
from .model import DocLayoutModel  # noqa: E402

__all__ = ["DocLayoutModel", "G2L_CRM"]
__version__ = "0.1.0"