Spaces:
Sleeping
Sleeping
| import ast | |
| import re | |
| import sys | |
| import warnings | |
| from importlib.abc import Loader, MetaPathFinder | |
| from importlib.machinery import ModuleSpec | |
| from importlib.util import spec_from_loader | |
| from types import ModuleType | |
| from typing import TYPE_CHECKING, Sequence, Union, Callable, Iterator, TypeVar, Any, cast | |
| from aworld.trace.base import log_trace_error | |
| from .rewrite_ast import compile_source | |
| if TYPE_CHECKING: | |
| from .context_manager import TraceManager | |
class AutoTraceModule:
    """A module that is being imported and may need to be traced automatically."""

    def __init__(self, module_name: str) -> None:
        """Remember the name of the module being imported.

        Args:
            module_name: Fully qualified absolute name of the module.
        """
        # Fully qualified absolute name of the module being imported.
        self._module_name = module_name

    def need_auto_trace(self, prefix: Union[str, Sequence[str]]) -> bool:
        """Check whether the module name matches any of the given prefixes.

        Each prefix is turned into a regex by `get_module_pattern`, the
        resulting patterns are joined with `|`, and the combined pattern is
        matched against the start of the module name.

        Args:
            prefix: A single prefix or a sequence of prefixes.

        Returns:
            True if the module name matches at least one prefix. An empty
            sequence of prefixes matches nothing.
        """
        if isinstance(prefix, str):
            prefix = (prefix,)
        if not prefix:
            # Joining zero patterns yields the empty regex, which matches every
            # module name; "no prefixes" must mean "no modules" instead.
            return False
        pattern = '|'.join([get_module_pattern(p) for p in prefix])
        return bool(re.match(pattern, self._module_name))
class TraceImportFinder(MetaPathFinder):
    """Meta path finder that hands selected modules to an `AutoTraceLoader`.

    For modules accepted by the filter, the source is obtained through the
    spec produced by the other finders on `sys.meta_path`, parsed, and passed
    to `compile_source`. Every other module is left to the regular import
    system by returning None from `find_spec`.
    """

    def __init__(self, trace_manager: "TraceManager", module_funcs: Callable[[AutoTraceModule], bool],
                 min_duration_ns: int) -> None:
        """Store the collaborators used while compiling traced modules.

        Args:
            trace_manager: Passed through to `compile_source`.
            module_funcs: Predicate deciding whether a module should be traced.
            min_duration_ns: Passed through to `compile_source`.
        """
        self._trace_manager = trace_manager
        self._modules_filter = module_funcs
        self._min_duration_ns = min_duration_ns

    def _find_plain_specs(
        self,
        fullname: str,
        path: Union[Sequence[str], None] = None,
        target: Union[ModuleType, None] = None,
    ) -> Iterator[ModuleSpec]:
        """Yield module specs returned by other finders on `sys.meta_path`."""
        for finder in sys.meta_path:
            # Skip this finder or any like it to avoid infinite recursion.
            if isinstance(finder, TraceImportFinder):
                continue
            try:
                plain_spec = finder.find_spec(fullname, path, target)
            except Exception:  # pragma: no cover
                # Covers finders without `find_spec` as well as failing ones.
                continue
            if plain_spec:
                yield plain_spec

    def find_spec(
        self,
        fullname: str,
        path: Union[Sequence[str], None],
        target: Union[ModuleType, None] = None,
    ) -> Union[ModuleSpec, None]:
        """Find the spec for the given module name.

        Args:
            fullname: Fully qualified name of the module being imported.
            path: Search path forwarded to the other finders.
            target: Optional target module forwarded to the other finders.

        Returns:
            A spec whose loader is an `AutoTraceLoader`, or None to tell the
            import system to try the next meta path finder.
        """
        # The filter only depends on the module name, so decide first and avoid
        # loading the source of every module that will not be traced anyway.
        if not self._modules_filter(AutoTraceModule(fullname)):
            return None

        for plain_spec in self._find_plain_specs(fullname, path, target):
            # Not every loader can provide source code; skip those that cannot.
            get_source = getattr(plain_spec.loader, 'get_source', None)
            if not callable(get_source):
                continue
            try:
                source = cast(str, get_source(fullname))
            except Exception:
                continue
            if not source:
                continue

            filename = plain_spec.origin
            if not filename:
                try:
                    filename = cast('str | None', plain_spec.loader.get_filename(fullname))
                except Exception:
                    # Best effort only: a placeholder name is used below.
                    pass
            filename = filename or f'<{fullname}>'

            try:
                tree = ast.parse(source)
            except Exception:
                # Invalid source code. Try another one.
                continue

            try:
                execute = compile_source(tree, filename, fullname, self._trace_manager, self._min_duration_ns)
            except Exception:  # pragma: no cover
                # Auto-tracing must not break the import: report the error and
                # fall back to the regular import machinery.
                log_trace_error()
                return None

            loader = AutoTraceLoader(plain_spec, execute)
            return spec_from_loader(fullname, loader)

        return None
class AutoTraceLoader(Loader):
    """
    A class that implements the `exec_module` method of the `Loader` protocol.

    It wraps the spec found by the regular import machinery together with a
    callable that executes the module's code in a given namespace.
    """

    def __init__(self, plain_spec: ModuleSpec, execute: Callable[[dict[str, Any]], None]) -> None:
        """Store the wrapped spec and the callable that runs the module.

        Args:
            plain_spec: Spec produced by another finder for the same module.
            execute: Callable invoked with the module namespace to run it.
        """
        self._plain_spec = plain_spec
        self._execute = execute

    def exec_module(self, module: ModuleType):
        """Execute a modified AST of the module's source code in the module's namespace."""
        self._execute(module.__dict__)

    def create_module(self, spec: ModuleSpec):
        """Return None so that the import system creates the module object itself."""
        return None

    def get_code(self, _name: str):
        """Return code that runs the module through this loader.

        `python -m` uses the `runpy` module which calls this method instead of
        going through the normal protocol, so return some code which can be
        executed with the module namespace.
        """
        # Here `__loader__` will be this object, i.e. `self`, so the attribute
        # name must match the one assigned in `__init__`.
        source = '__loader__._execute(globals())'
        return compile(source, '<string>', 'exec', dont_inherit=True)

    def __getattr__(self, item: str):
        """Forward some methods to the plain spec's loader (likely a `SourceFileLoader`) if they exist."""
        if item in {'get_filename', 'is_package'}:
            # `_plain_spec` is the attribute set in `__init__`; any other name
            # re-enters `__getattr__` and always raises AttributeError.
            return getattr(self._plain_spec.loader, item)
        raise AttributeError(item)
def convert_to_modules_func(modules: Sequence[str]) -> Callable[[AutoTraceModule], bool]:
    """Build a module filter from a sequence of module names.

    The returned callable asks the module it receives whether it needs
    auto-tracing for the given names, via `need_auto_trace(modules)`.
    """

    def matches_any(module):
        return module.need_auto_trace(modules)

    return matches_any
def get_module_pattern(module: str):
    """
    Build the regex pattern used to match the given module name.

    A value made up only of word characters and dots is treated as a literal
    module name: it is escaped and must be followed by the end of the string
    or a dot. Anything else is returned unchanged and used as a regex as is.
    """
    looks_like_plain_name = re.match(r'[\w.]+$', module, re.UNICODE) is not None
    if looks_like_plain_name:
        escaped = re.escape(module)
        return rf'{escaped}($|\.)'
    return module
def install_auto_tracing(trace_manager: "TraceManager",
                         modules: Union[Sequence[str],
                         Callable[[AutoTraceModule], bool]],
                         min_duration_seconds: float
                         ) -> None:
    """
    Install an import hook that auto-traces the selected modules.

    `modules` is either a sequence of module names, which is converted into a
    filter, or a callable used as the filter directly. A warning is issued for
    every already imported module that the filter accepts. Finally a
    `TraceImportFinder` is placed at the front of `sys.meta_path`.

    Raises:
        TypeError: if `modules` is neither a sequence nor a callable.
    """
    module_funcs = convert_to_modules_func(modules) if isinstance(modules, Sequence) else modules
    if not callable(module_funcs):
        raise TypeError('modules must be a list of strings or a callable')

    # Iterate over a snapshot of the values of `sys.modules`.
    for loaded in list(sys.modules.values()):
        try:
            candidate = AutoTraceModule(loaded.__name__)
        except Exception:
            # Entries without a usable `__name__` are skipped.
            continue
        if not module_funcs(candidate):
            continue
        warnings.warn(
            f'The module {loaded.__name__!r} matches modules to trace, but it has already been imported. '
            f'Call `auto_tracing` earlier',
            stacklevel=2,
        )

    # Convert the threshold from seconds to whole nanoseconds.
    threshold_ns = int(min_duration_seconds * 1_000_000_000)
    auto_manager = trace_manager.new_manager('auto_tracing')
    sys.meta_path.insert(0, TraceImportFinder(auto_manager, module_funcs, threshold_ns))
T = TypeVar('T')


def not_auto_trace(x: T) -> T:
    """Decorator marking a function/class that `auto_tracing` must not trace.

    The decorated object is handed back as is; nothing is wrapped or modified.
    """
    return x