# Miaode's picture
# Add files using upload-large-folder tool
# e020674 verified
import importlib
import importlib.util
import sys
import types
import os
from dataflow.logger import get_logger
from pathlib import Path
from rich.console import Console
from rich.table import Table
import ast
from pathlib import Path
def generate_import_structure_from_type_checking(source_file: str, base_path: str) -> dict:
    """Build a lazy-import table from the ``if TYPE_CHECKING:`` blocks of a file.

    The file is parsed (never executed) and every ``from <module> import <name>``
    statement found directly inside an ``if TYPE_CHECKING:`` or
    ``if typing.TYPE_CHECKING:`` body contributes one entry.

    Args:
        source_file: Path of the Python source file to inspect (read as UTF-8).
        base_path: Directory that the dotted module names are resolved against.

    Returns:
        A dict mapping each imported name to ``(module_file, name)`` where
        ``module_file`` is ``base_path/<module with dots as slashes>.py``.
        The relative-import level (leading dots) is ignored, and an ``as``
        alias does not change the key: the original imported name is used.

    Raises:
        OSError: If ``source_file`` cannot be read.
        SyntaxError: If ``source_file`` is not valid Python.
    """
    source = Path(source_file).read_text(encoding="utf-8")
    tree = ast.parse(source)
    base_dir = Path(base_path)

    import_structure = {}
    for node in ast.walk(tree):
        if not isinstance(node, ast.If):
            continue
        test = node.test
        # Accept both the bare name and the attribute form (typing.TYPE_CHECKING).
        is_type_checking = (
            (isinstance(test, ast.Name) and test.id == 'TYPE_CHECKING')
            or (isinstance(test, ast.Attribute) and test.attr == 'TYPE_CHECKING')
        )
        if not is_type_checking:
            continue

        for subnode in node.body:
            if not isinstance(subnode, ast.ImportFrom):
                continue
            # ``from . import x`` has no module part, so no file can be derived.
            if subnode.module is None:
                continue
            module_rel = subnode.module.replace(".", "/")
            module_file = str(base_dir / f"{module_rel}.py")
            for alias in subnode.names:
                name = alias.name
                # A star import names no concrete attribute to load lazily.
                if name == "*":
                    continue
                import_structure[name] = (module_file, name)
    return import_structure
class Registry():
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)

    Objects that were never registered explicitly are looked up, on demand, as
    attributes of the ``dataflow.<name>.<sub_module>`` modules listed in
    ``sub_modules``.
    """

    def __init__(self, name, sub_modules: "list[str] | None" = None):
        """
        Args:
            name (str): the name of this registry. It is also the package
                segment used to import the sub-modules
                (``dataflow.<name>.<sub_module>``).
            sub_modules: names of the sub-modules searched by :meth:`get`
                when a name has not been registered. ``None`` or an empty
                list means there is nothing to search.
        """
        self._name = name
        self._obj_map = {}
        # sub-module name -> imported module (None until first needed).
        # Always defined, even when empty, so get() and _get_all() also work
        # on registries created without sub-modules.
        self.loader_map = dict.fromkeys(sub_modules or ())

    def _init_loaders(self):
        """Import every configured sub-module and store it in ``loader_map``."""
        for module_name in self.loader_map:
            module_path = f"dataflow.{self._name}.{module_name}"
            self.loader_map[module_name] = importlib.import_module(module_path)

    def _do_register(self, name, obj):
        """Store ``obj`` under ``name``; an existing entry is left untouched."""
        if name not in self._obj_map:
            self._obj_map[name] = obj

    def register(self, obj=None):
        """
        Register the given object under the name `obj.__name__`.
        Can be used as either a decorator or not.
        See docstring of this class for usage.

        Returns:
            A decorator when called without an argument; otherwise ``obj``
            itself, so that a bare ``@registry.register`` keeps the decorated
            name bound to the object instead of rebinding it to ``None``.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                self._do_register(func_or_class.__name__, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        self._do_register(obj.__name__, obj)
        return obj

    def get(self, name):
        """
        Return the object registered as ``name``.

        If the name is not registered yet, each sub-module in ``loader_map``
        is asked for an attribute called ``name``; the first hit is cached in
        the registry and returned.

        Raises:
            KeyError: if no registered object and no sub-module provides
                ``name``.
        """
        ret = self._obj_map.get(name)
        if ret is not None:
            return ret

        # Slow path only: the logger is not needed for a plain hit.
        logger = get_logger()
        if None in self.loader_map.values():
            self._init_loaders()
        for module_lib in self.loader_map.values():
            try:
                clss = getattr(module_lib, name)
            except AttributeError as e:
                # This sub-module does not provide the name; try the next one.
                logger.debug(f"{str(e)}")
                continue
            self._obj_map[name] = clss
            return clss

        message = f"No object named '{name}' found in '{self._name}' registry!"
        logger.error(message)
        raise KeyError(message)

    def __contains__(self, name):
        """Tell whether ``name`` is already present in the object map."""
        return name in self._obj_map

    def __iter__(self):
        """Iterate over ``(name, object)`` pairs of the object map."""
        return iter(self._obj_map.items())

    def keys(self):
        """Return the names currently present in the object map."""
        return self._obj_map.keys()

    def __repr__(self):
        """Render the object map, sorted by name, as a rich text table."""
        table = Table(title=f'Registry of {self._name}')
        table.add_column('Names', justify='left', style='cyan')
        table.add_column('Objects', justify='left', style='green')

        for name, obj in sorted(self._obj_map.items()):
            table.add_row(name, str(obj))

        console = Console()
        with console.capture() as capture:
            console.print(table, end='')

        return capture.get()

    def _get_all(self):
        """
        Import all sub-modules and call ``_import_all()`` on each of them.

        NOTE(review): assumes every sub-module object exposes ``_import_all``
        (as ``LazyLoader`` does) - confirm for custom sub-modules.
        """
        if None in self.loader_map.values():
            self._init_loaders()
        for loader in self.loader_map.values():
            loader._import_all()

    def get_obj_map(self):
        """
        Get the object map of the registry.
        """
        return self._obj_map
# Shared registry of operators. Names that were not registered explicitly are
# resolved on first lookup from the dataflow.operators.<sub_module> modules
# listed here.
OPERATOR_REGISTRY = Registry(
    name='operators',
    sub_modules=['eval', 'filter', 'generate', 'refine', 'conversations'],
)
class LazyLoader(types.ModuleType):
    """Module object whose classes are loaded from source files on first access.

    Attribute access for a name listed in ``import_structure`` executes the
    corresponding file, extracts the class, and caches it for later lookups.
    """

    def __init__(self, name, path, import_structure):
        """
        Initialise the LazyLoader module.

        :param name: module name
        :param path: value stored as the single entry of ``__path__``
        :param import_structure: dict mapping each class name to a
            ``(file path, class name)`` pair
        """
        super().__init__(name)
        self._import_structure = import_structure
        self._loaded_classes = {}
        # Relative file paths are resolved against the directory two levels
        # above the directory containing this file.
        self._base_folder = Path(__file__).resolve().parents[2]
        self.__path__ = [path]
        self.__all__ = list(import_structure.keys())

    def _import_all(self):
        """Eagerly load every class listed in ``__all__``."""
        for cls_name in self.__all__:
            self.__getattr__(cls_name)

    def _load_class_from_file(self, file_path, class_name):
        """
        Load a class from the given file.

        :param file_path: path of the script file; relative paths are resolved
            against ``self._base_folder``
        :param class_name: name of the class
        :return: the class object
        :raises FileNotFoundError: if the resolved file does not exist
        :raises ValueError: if the resolved file is not located under
            ``self._base_folder``
        :raises AttributeError: if the executed file defines no ``class_name``
        """
        p = Path(file_path)
        if p.is_absolute():
            abs_file_path = str(p)
        else:
            abs_file_path = str(Path(self._base_folder) / p)

        if not os.path.exists(abs_file_path):
            raise FileNotFoundError(abs_file_path)

        rel_path = Path(abs_file_path).relative_to(self._base_folder)
        # Drop the suffix to get e.g.
        # ('dataflow', 'operators', 'generate', ..., 'question_generator')
        rel_parts = rel_path.with_suffix('').parts
        prefix_parts = tuple(self.__name__.split('.'))
        # Avoid repeating this module's own dotted name in the result.
        if rel_parts[:len(prefix_parts)] == prefix_parts:
            rel_parts = rel_parts[len(prefix_parts):]
        mod_name = '.'.join((*prefix_parts, *rel_parts))

        logger = get_logger()
        # Load the module dynamically.
        previous = sys.modules.get(mod_name)
        try:
            spec = importlib.util.spec_from_file_location(mod_name, abs_file_path)
            logger.debug(f"LazyLoader {self.__path__} successfully imported spec {spec.__str__()}")
            module = importlib.util.module_from_spec(spec)
            # Registered before execution so the file can import itself by name.
            sys.modules[mod_name] = module
            logger.debug(f"LazyLoader {self.__path__} successfully imported module {module.__str__()} from spec {spec.__str__()}")
            try:
                spec.loader.exec_module(module)
            except BaseException:
                # Do not leave a half-initialised module importable; put back
                # whatever was registered under this name before.
                if previous is None:
                    sys.modules.pop(mod_name, None)
                else:
                    sys.modules[mod_name] = previous
                raise
        except Exception as e:
            logger.error(f"{e.__str__()}")
            raise

        # Extract the class.
        if not hasattr(module, class_name):
            raise AttributeError(f"Class {class_name} not found in {abs_file_path}")
        return getattr(module, class_name)

    def __getattr__(self, item):
        """
        Load a class dynamically.

        Only called when normal attribute lookup fails.

        :param item: class name
        :return: the dynamically loaded class object
        :raises AttributeError: if ``item`` is not a known lazy class
        """
        # Read the state through __dict__: going through normal attribute
        # access here would re-enter __getattr__ without end when the instance
        # was created without running __init__.
        state = self.__dict__
        loaded_classes = state.get("_loaded_classes")
        import_structure = state.get("_import_structure")
        if loaded_classes is None or import_structure is None:
            raise AttributeError(
                f"{type(self).__name__} is not initialised; no attribute {item}"
            )

        logger = get_logger()
        if item in loaded_classes:
            cls = loaded_classes[item]
            logger.debug(f"Lazyloader {self.__path__} got cached class {cls}")
            return cls
        # Get the file path and class name from the mapping.
        if item in import_structure:
            file_path, class_name = import_structure[item]
            logger.info(f"Lazyloader {self.__path__} trying to import {item} ")
            cls = self._load_class_from_file(file_path, class_name)
            logger.debug(f"Lazyloader {self.__path__} got and cached class {cls}")
            loaded_classes[item] = cls
            return cls
        logger.debug(f"Module {self.__name__} has no attribute {item}")
        raise AttributeError(f"Module {self.__name__} has no attribute {item}")