File size: 8,353 Bytes
e020674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
import importlib
import importlib.util
import sys
import types
import os
from dataflow.logger import get_logger
from pathlib import Path

from rich.console import Console
from rich.table import Table

import ast
from pathlib import Path

def generate_import_structure_from_type_checking(source_file: str, base_path: str) -> dict:
    """
    Build a lazy-import map from the ``if TYPE_CHECKING:`` blocks of a module.

    Every ``from <module> import <Name> [as <Alias>]`` statement placed directly
    inside an ``if TYPE_CHECKING:`` (or ``if typing.TYPE_CHECKING:``) block
    becomes one entry. The module's dotted path is turned into a file path
    ``<base_path>/<module with '.' replaced by '/'>.py``.

    Args:
        source_file: path of the Python file to scan (read as UTF-8).
        base_path: directory that module file paths are built relative to.

    Returns:
        dict mapping the name visible to importers (the alias if one is given,
        otherwise the imported name) to ``(module_file, name_in_module)``.
        ``from . import x`` statements (no module name) are skipped because
        they carry no module path to build a file name from.
    """
    source = Path(source_file).read_text(encoding="utf-8")
    tree = ast.parse(source, filename=str(source_file))
    base = Path(base_path)

    import_structure = {}
    for node in ast.walk(tree):
        if not isinstance(node, ast.If):
            continue
        test = node.test
        # Accept both `if TYPE_CHECKING:` and `if typing.TYPE_CHECKING:`.
        guarded = (
            (isinstance(test, ast.Name) and test.id == "TYPE_CHECKING")
            or (isinstance(test, ast.Attribute) and test.attr == "TYPE_CHECKING")
        )
        if not guarded:
            continue
        for subnode in node.body:
            # `from . import x` has module=None; there is no path to derive.
            if not isinstance(subnode, ast.ImportFrom) or subnode.module is None:
                continue
            module_file = str(base / f"{subnode.module.replace('.', '/')}.py")
            for alias in subnode.names:
                # Key by the name importers see; keep the real name so the
                # loader fetches the right attribute from the module file.
                import_structure[alias.asname or alias.name] = (module_file, alias.name)

    return import_structure


class Registry:
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    Besides explicitly registered objects, a registry created with
    ``sub_modules`` resolves unknown names lazily: on a miss, ``get`` imports
    ``dataflow.<name>.<sub_module>`` for every configured sub-module and looks
    the requested name up as an attribute of those modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name, sub_modules: "list[str] | None" = None):
        """
        Args:
            name (str): the name of this registry; also the package segment
                used for lazy imports (``dataflow.<name>.<sub_module>``).
            sub_modules (list[str] | None): sub-module names searched by ``get``
                when a name is not registered. ``None`` or empty disables the
                lazy lookup.
        """
        self._name = name
        self._obj_map = {}
        # Always defined (possibly empty) so get() and _get_all() also work on
        # registries without sub-modules. Values stay None until the matching
        # module is imported by _init_loaders().
        self.loader_map = dict.fromkeys(sub_modules or ())

    def _init_loaders(self):
        """Import every configured sub-module and store it in ``loader_map``."""
        for module_name in self.loader_map.keys():
            module_path = f"dataflow.{self._name}.{module_name}"
            self.loader_map[module_name] = importlib.import_module(module_path)

    def _do_register(self, name, obj):
        """Store ``obj`` under ``name``; the first registration of a name wins."""
        if name not in self._obj_map:
            self._obj_map[name] = obj

    def register(self, obj=None):
        """
        Register the given object under the name `obj.__name__`.
        Can be used as either a decorator or not.
        See docstring of this class for usage.

        Returns:
            The decorator when called without ``obj``; otherwise ``obj`` itself,
            so the plain call form can be used inline.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                self._do_register(func_or_class.__name__, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        self._do_register(obj.__name__, obj)
        return obj

    def get(self, name):
        """
        Return the object registered as ``name``.

        On a miss, the configured sub-modules are imported (once) and searched
        in order; the first attribute found is cached in the object map.

        Raises:
            KeyError: if the name is neither registered nor an attribute of any
                configured sub-module.
        """
        ret = self._obj_map.get(name)
        if ret is not None:
            return ret

        logger = get_logger()
        if None in self.loader_map.values():
            self._init_loaders()
        for module_lib in self.loader_map.values():
            try:
                clss = getattr(module_lib, name)
            except AttributeError as e:
                # Not in this sub-module; try the next one. Any other error
                # (e.g. a failing lazy import) propagates to the caller.
                logger.debug(f"{str(e)}")
                continue
            self._obj_map[name] = clss
            return clss

        logger.error(f"No object named '{name}' found in '{self._name}' registry!")
        raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")

    def __contains__(self, name):
        """True if ``name`` is already in the object map (no lazy lookup)."""
        return name in self._obj_map

    def __iter__(self):
        """Iterate over ``(name, obj)`` pairs currently in the object map."""
        return iter(self._obj_map.items())

    def keys(self):
        """Names currently in the object map."""
        return self._obj_map.keys()

    def __repr__(self):
        """Render the object map as a rich table, captured to a string."""
        table = Table(title=f'Registry of {self._name}')
        table.add_column('Names', justify='left', style='cyan')
        table.add_column('Objects', justify='left', style='green')

        for name, obj in sorted(self._obj_map.items()):
            table.add_row(name, str(obj))

        console = Console()
        with console.capture() as capture:
            console.print(table, end='')

        return capture.get()

    def _get_all(self):
        """
        Import the sub-modules and call ``_import_all()`` on each of them.

        NOTE(review): assumes every sub-module object provides ``_import_all``
        (e.g. is a LazyLoader) -- confirm for all configured sub-modules.
        """
        if None in self.loader_map.values():
            self._init_loaders()
        for loader in self.loader_map.values():
            loader._import_all()

    def get_obj_map(self):
        """
        Get the object map of the registry.
        """
        return self._obj_map

# Registry of operator classes. Names that are not registered explicitly are
# looked up in dataflow.operators.<sub_module> for each sub-module listed here.
OPERATOR_REGISTRY = Registry(
    name='operators',
    sub_modules=[
        'eval',
        'filter',
        'generate',
        'refine',
        'conversations',
    ],
)
class LazyLoader(types.ModuleType):
    """
    A module object whose attributes are classes loaded on first access.

    ``import_structure`` maps an attribute name to ``(file_path, class_name)``.
    Reading ``loader.<name>`` executes the corresponding file, returns the class
    and caches it for later accesses.
    """

    def __init__(self, name, path, import_structure):
        """
        Initialize the LazyLoader module.

        :param name: module name; also used as the dotted prefix of the names
            under which loaded files are registered in ``sys.modules``
        :param path: stored as the single entry of ``__path__``
        :param import_structure: dict mapping class name -> (file path, class name)
        """
        super().__init__(name)
        self._import_structure = import_structure
        self._loaded_classes = {}
        # Relative file paths are resolved against this directory: two levels
        # above the directory containing this source file.
        self._base_folder = Path(__file__).resolve().parents[2]
        self.__path__ = [path]
        self.__all__ = list(import_structure.keys())

    def _import_all(self):
        """Eagerly load every class listed in ``__all__``."""
        for cls_name in self.__all__:
            self.__getattr__(cls_name)

    def _load_class_from_file(self, file_path, class_name):
        """
        Load a class from the given file.

        :param file_path: path of the script file; relative paths are resolved
            against ``self._base_folder``
        :param class_name: name of the class to fetch from the module
        :return: the class object
        :raises FileNotFoundError: if the resolved file does not exist
        :raises ImportError: if no import spec can be built for the file
        :raises AttributeError: if the module does not define ``class_name``
        """
        p = Path(file_path)
        if p.is_absolute():
            abs_file_path = str(p)
        else:
            abs_file_path = str(Path(self._base_folder) / p)
        if not os.path.exists(abs_file_path):
            raise FileNotFoundError(abs_file_path)
        rel_path = Path(abs_file_path).relative_to(self._base_folder)
        # Drop the suffix, e.g. ('dataflow', 'operators', 'generate', ..., 'question_generator')
        rel_parts = rel_path.with_suffix('').parts
        prefix_parts = tuple(self.__name__.split('.'))
        if rel_parts[:len(prefix_parts)] == prefix_parts:
            rel_parts = rel_parts[len(prefix_parts):]
        mod_name = '.'.join((*prefix_parts, *rel_parts))
        logger = get_logger()

        # Reuse a module already loaded from this same file, so that several
        # classes from one file share one module object (and class identity)
        # instead of the file being re-executed once per class. The file check
        # comes first so hasattr() is never called on an unrelated module.
        existing = sys.modules.get(mod_name)
        existing_file = getattr(existing, "__file__", None)
        if (
            existing_file is not None
            and os.path.realpath(existing_file) == os.path.realpath(abs_file_path)
            and hasattr(existing, class_name)
        ):
            logger.debug(f"LazyLoader {self.__path__} reusing loaded module {existing.__str__()}")
            return getattr(existing, class_name)

        # Load the module dynamically.
        try:
            spec = importlib.util.spec_from_file_location(mod_name, abs_file_path)
            if spec is None or spec.loader is None:
                raise ImportError(f"Cannot build an import spec for {abs_file_path}")
            logger.debug(f"LazyLoader {self.__path__} successfully imported spec {spec.__str__()}")
            module = importlib.util.module_from_spec(spec)
            # Make the module visible in sys.modules while its body runs.
            sys.modules[mod_name] = module
            logger.debug(f"LazyLoader {self.__path__} successfully imported module {module.__str__()} from spec {spec.__str__()}")
            try:
                spec.loader.exec_module(module)
            except BaseException:
                # Don't leave a half-initialised module registered; restore
                # whatever was there before (if anything).
                if sys.modules.get(mod_name) is module:
                    if existing is not None:
                        sys.modules[mod_name] = existing
                    else:
                        del sys.modules[mod_name]
                raise
        except Exception as e:
            logger.error(f"{e.__str__()}")
            raise

        # Extract the class.
        if not hasattr(module, class_name):
            raise AttributeError(f"Class {class_name} not found in {abs_file_path}")
        return getattr(module, class_name)

    def __getattr__(self, item):
        """
        Load a class on first attribute access.

        :param item: class name
        :return: the loaded (and cached) class object
        :raises AttributeError: if ``item`` is not in the import structure
        """
        # __getattr__ only runs when normal lookup fails. If these internals are
        # missing (instance not initialised via __init__), reading them below
        # would re-enter __getattr__ without bound, so fail fast instead. The
        # message avoids self.__name__, which may be missing too.
        if item in ("_import_structure", "_loaded_classes"):
            raise AttributeError(f"LazyLoader has no attribute {item}")
        logger = get_logger()
        if item in self._loaded_classes:
            cls = self._loaded_classes[item]
            logger.debug(f"Lazyloader {self.__path__} got cached class {cls}")
            return cls
        # Look up the file path and class name in the import structure.
        if item in self._import_structure:
            file_path, class_name = self._import_structure[item]
            logger.info(f"Lazyloader {self.__path__} trying to import {item} ")
            cls = self._load_class_from_file(file_path, class_name)
            logger.debug(f"Lazyloader {self.__path__} got and cached class {cls}")
            self._loaded_classes[item] = cls
            return cls
        logger.debug(f"Module {self.__name__} has no attribute {item}")
        raise AttributeError(f"Module {self.__name__} has no attribute {item}")