File size: 2,837 Bytes
c147abc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""
DocLayout-YOLO inference module.

This module patches ultralytics to support the G2L_CRM custom layers so that
DocLayout-YOLO model weights can be loaded. The patch is applied automatically
when this package is imported.

Usage:
    from doclayout_yolo import DocLayoutModel

    model = DocLayoutModel("model.pt")
    results = model.predict("document.png")
"""

import sys

# Import our custom modules first
from .g2l_crm import G2L_CRM, DilatedBlock, DilatedBottleneck


def _patch_ultralytics():
    """
    Patch ultralytics so it can build and unpickle models that use G2L_CRM.

    If ``ultralytics.nn.modules`` already exposes ``G2L_CRM`` (a fork with
    native support is installed, or this function has already run), nothing
    is changed, so calling it more than once is harmless.

    This must be called before loading any models that use G2L_CRM.

    Raises:
        ImportError: If ultralytics is not installed.
    """
    from types import ModuleType

    try:
        import ultralytics.nn.modules as modules
        import ultralytics.nn.tasks as tasks
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "ultralytics is required. Install with: pip install ultralytics"
        ) from err

    # Already patched, or a fork that ships G2L_CRM natively is installed.
    if hasattr(modules, "G2L_CRM"):
        return

    custom_layers = {
        "G2L_CRM": G2L_CRM,
        "DilatedBlock": DilatedBlock,
        "DilatedBottleneck": DilatedBottleneck,
    }

    # Expose the custom layers as attributes of ultralytics.nn.modules.
    for layer_name, layer_cls in custom_layers.items():
        setattr(modules, layer_name, layer_cls)

    # The weights file references classes under
    # "ultralytics.nn.modules.g2l_crm", so that dotted module path must be
    # importable when PyTorch unpickles the checkpoint.
    fake_module = ModuleType("ultralytics.nn.modules.g2l_crm")
    for layer_name, layer_cls in custom_layers.items():
        setattr(fake_module, layer_name, layer_cls)
    sys.modules[fake_module.__name__] = fake_module
    # Bind it on the parent package too, as a real submodule import would.
    setattr(modules, "g2l_crm", fake_module)

    # parse_model resolves layer names via the tasks module's globals().
    tasks.G2L_CRM = G2L_CRM

    # Wrap parse_model so G2L_CRM is guaranteed to be in its globals while
    # it parses a model definition.
    _patch_parse_model(tasks)


def _patch_parse_model(tasks, names=None):
    """
    Wrap ``tasks.parse_model`` so extra layer classes are visible to it.

    ``parse_model`` looks layers up via ``globals()[m]``, so every entry of
    *names* is injected into the wrapped function's module globals for the
    duration of each call. Afterwards each name is restored to its previous
    state: re-bound to its old value if it had one, removed otherwise.

    Args:
        tasks: Object (normally the ``ultralytics.nn.tasks`` module) whose
            ``parse_model`` attribute is replaced by the wrapper.
        names: Mapping of global name -> object to inject. Defaults to
            ``{"G2L_CRM": G2L_CRM}``.
    """
    import functools

    if names is None:
        names = {"G2L_CRM": G2L_CRM}
    # Snapshot so later mutation of the caller's mapping has no effect.
    names = dict(names)

    original_parse_model = tasks.parse_model
    target_globals = original_parse_model.__globals__
    missing = object()  # sentinel: name was absent before injection

    @functools.wraps(original_parse_model)
    def patched_parse_model(d, ch, verbose=True):
        # HACK: necessary because parse_model resolves names via globals().
        # Record only the affected keys instead of copying the whole dict.
        saved = {key: target_globals.get(key, missing) for key in names}
        target_globals.update(names)
        try:
            return original_parse_model(d, ch, verbose)
        finally:
            for key, previous in saved.items():
                if previous is missing:
                    target_globals.pop(key, None)
                else:
                    target_globals[key] = previous

    tasks.parse_model = patched_parse_model


# Apply patch on import. It runs before `.model` is imported below, because
# _patch_ultralytics must be called before loading any model that uses G2L_CRM.
_patch_ultralytics()

# Public API
from .model import DocLayoutModel

# Names exported by a star-import of this package.
__all__ = ["DocLayoutModel", "G2L_CRM"]
__version__ = "0.1.0"