DeepSolanaCoder/DeepSeek-Coder-main/finetune/venv/lib/python3.12/site-packages/datasets/utils/patching.py
import builtins
from importlib import import_module

from .logging import get_logger
# Module-level logger obtained from the package's own logging helper (``.logging.get_logger``),
# named after this module.
logger = get_logger(__name__)
class _PatchedModuleObj:
    """Shallow, freely-mutable stand-in for a module.

    Every entry of ``module.__dict__`` whose name does not start with ``"__"``
    is copied onto the instance; names starting with ``"__"`` are copied only
    when they appear in ``attrs``. The wrapped module is remembered in
    ``_original_module``. When ``module`` is itself a ``_PatchedModuleObj``,
    its ``_original_module`` is reused, so the reference always points at the
    module underneath any stack of wrappers (or ``None`` for an empty wrapper).
    """

    def __init__(self, module, attrs=None):
        # Extra double-underscore names the caller explicitly wants copied.
        wanted_dunders = attrs if attrs else []
        if module is not None:
            for name in module.__dict__:
                starts_with_dunder = name.startswith("__")
                if not starts_with_dunder or name in wanted_dunders:
                    setattr(self, name, getattr(module, name))
        if isinstance(module, _PatchedModuleObj):
            # Unwrap so that chained wrappers still reference the real module.
            self._original_module = module._original_module
        else:
            self._original_module = module
class patch_submodule:
    """
    Patch a submodule attribute of an object, by keeping all other submodules intact at all levels.

    The patch is applied by ``__enter__`` (or :meth:`start`) and undone by
    ``__exit__`` (or :meth:`stop`). For a dotted ``target`` such as
    ``"os.path.join"``, every attribute of ``obj`` whose value *is* one of the
    importable modules on that path (or a ``_PatchedModuleObj`` wrapping it) is
    replaced by a ``_PatchedModuleObj`` copy carrying ``new``, and every
    attribute bound directly to the target value (``from os.path import join``)
    is replaced by ``new``. A target without dots must name a builtin
    (e.g. ``"open"``); otherwise ``RuntimeError`` is raised.

    Example::

        >>> import importlib
        >>> from datasets.load import dataset_module_factory
        >>> from datasets.streaming import patch_submodule, xjoin
        >>>
        >>> dataset_module = dataset_module_factory("snli")
        >>> snli_module = importlib.import_module(dataset_module.module_path)
        >>> patcher = patch_submodule(snli_module, "os.path.join", xjoin)
        >>> patcher.start()
        >>> assert snli_module.os.path.join is xjoin
    """

    # Class-level list shared by every instance: patchers activated with
    # start() that have not been stopped yet.
    _active_patches = []

    def __init__(self, obj, target: str, new, attrs=None):
        """
        Args:
            obj: Object (e.g. a module) whose attributes get patched.
            target: Dotted path of the attribute to replace, e.g. ``"os.path.join"``,
                or a bare builtin name such as ``"open"``.
            new: Value installed in place of the target.
            attrs: Optional names starting with ``"__"`` that should still be copied
                onto the ``_PatchedModuleObj`` wrappers (such names are skipped otherwise).
        """
        self.obj = obj
        self.target = target
        self.new = new
        # First component of the dotted target (e.g. "os" for "os.path.join").
        self.key = target.split(".")[0]
        # Attribute name on ``obj`` -> value it held before patching; emptied on exit.
        self.original = {}
        self.attrs = attrs or []

    def __enter__(self):
        """Apply the patch and return this patcher.

        Submodule paths that cannot be imported (``ModuleNotFoundError``) or that
        lack the target attribute (``AttributeError``) are silently skipped.

        Raises:
            RuntimeError: if ``target`` contains no dot and is not a builtin name.
        """
        *submodules, target_attr = self.target.split(".")

        # Patch modules:
        # it's used to patch attributes of submodules like "os.path.join";
        # in this case we need to patch "os" and "os.path"
        for depth in range(1, len(submodules) + 1):
            try:
                submodule = import_module(".".join(submodules[:depth]))
            except ModuleNotFoundError:
                continue
            # We iterate over all the globals in self.obj in case we find "os" or "os.path"
            for attr in self.obj.__dir__():
                obj_attr = getattr(self.obj, attr)
                # We don't check for the name of the global, but rather if its value *is* "os" or "os.path".
                # This allows to patch renamed modules like "from os import path as ospath".
                # A wrapper left by an enclosing patch of the same module matches too.
                if obj_attr is submodule or (
                    isinstance(obj_attr, _PatchedModuleObj) and obj_attr._original_module is submodule
                ):
                    self.original[attr] = obj_attr
                    # patch at top level
                    setattr(self.obj, attr, _PatchedModuleObj(obj_attr, attrs=self.attrs))
                    patched = getattr(self.obj, attr)
                    # construct lower levels patches
                    for key in submodules[depth:]:
                        setattr(patched, key, _PatchedModuleObj(getattr(patched, key, None), attrs=self.attrs))
                        patched = getattr(patched, key)
                    # finally set the target attribute
                    setattr(patched, target_attr, self.new)

        # Patch attribute itself:
        # it's used for builtins like "open",
        # and also to patch "os.path.join" we may also need to patch "join"
        # itself if it was imported as "from os.path import join".
        if submodules:  # if it's an attribute of a submodule like "os.path.join"
            try:
                attr_value = getattr(import_module(".".join(submodules)), target_attr)
            except (AttributeError, ModuleNotFoundError):
                return self
            # We iterate over all the globals in self.obj in case we find "os.path.join"
            for attr in self.obj.__dir__():
                # We don't check for the name of the global, but rather if its value *is* "os.path.join".
                # This allows to patch renamed attributes like "from os.path import join as pjoin".
                current = getattr(self.obj, attr)
                if current is attr_value:
                    self.original[attr] = current
                    setattr(self.obj, attr, self.new)
        else:  # it's a builtin like "open"
            # Use the ``builtins`` module: ``globals()["__builtins__"]`` is a CPython
            # implementation detail (a dict in imported modules, the module itself in
            # ``__main__``, where ``in``/``[...]`` would raise TypeError).
            builtin_namespace = vars(builtins)
            if target_attr not in builtin_namespace:
                raise RuntimeError(f"Tried to patch attribute {target_attr} instead of a submodule.")
            # Record what ``obj`` itself exposes right now (e.g. an enclosing patch of the
            # same builtin), falling back to the real builtin, so __exit__ restores it
            # instead of clobbering an outer patch with the builtin.
            self.original[target_attr] = getattr(self.obj, target_attr, builtin_namespace[target_attr])
            setattr(self.obj, target_attr, self.new)
        return self

    def __exit__(self, *exc_info):
        """Restore every attribute recorded in ``self.original`` and clear it.

        Returns ``None``, so exceptions raised inside a ``with`` block propagate.
        """
        for attr in list(self.original):
            setattr(self.obj, attr, self.original.pop(attr))

    def start(self):
        """Activate a patch."""
        self.__enter__()
        # Only reached if __enter__ succeeded, so failed patches are never registered.
        self._active_patches.append(self)

    def stop(self):
        """Stop an active patch."""
        try:
            self._active_patches.remove(self)
        except ValueError:
            # If the patch hasn't been started this will fail
            return None

        return self.__exit__()