| | import io |
| | from contextlib import contextmanager |
| |
|
| | import mmengine.fileio as fileio |
| | from mmengine.fileio import LocalBackend, get_file_backend |
| |
|
| |
|
| | def patch_func(module, fn_name_to_wrap): |
| | backup = getattr(patch_func, '_backup', []) |
| | fn_to_wrap = getattr(module, fn_name_to_wrap) |
| |
|
| | def wrap(fn_new): |
| | setattr(module, fn_name_to_wrap, fn_new) |
| | backup.append((module, fn_name_to_wrap, fn_to_wrap)) |
| | setattr(fn_new, '_fallback', fn_to_wrap) |
| | setattr(patch_func, '_backup', backup) |
| | return fn_new |
| |
|
| | return wrap |
| |
|
| |
|
| | @contextmanager |
| | def patch_fileio(global_vars=None): |
| | if getattr(patch_fileio, '_patched', False): |
| | |
| | yield |
| | return |
| | import builtins |
| |
|
| | @patch_func(builtins, 'open') |
| | def open(file, mode='r', *args, **kwargs): |
| | backend = get_file_backend(file) |
| | if isinstance(backend, LocalBackend): |
| | return open._fallback(file, mode, *args, **kwargs) |
| | if 'b' in mode: |
| | return io.BytesIO(backend.get(file, *args, **kwargs)) |
| | else: |
| | return io.StringIO(backend.get_text(file, *args, **kwargs)) |
| |
|
| | if global_vars is not None and 'open' in global_vars: |
| | bak_open = global_vars['open'] |
| | global_vars['open'] = builtins.open |
| |
|
| | import os |
| |
|
| | @patch_func(os.path, 'join') |
| | def join(a, *paths): |
| | backend = get_file_backend(a) |
| | if isinstance(backend, LocalBackend): |
| | return join._fallback(a, *paths) |
| | paths = [item for item in paths if len(item) > 0] |
| | return backend.join_path(a, *paths) |
| |
|
| | @patch_func(os.path, 'isdir') |
| | def isdir(path): |
| | backend = get_file_backend(path) |
| | if isinstance(backend, LocalBackend): |
| | return isdir._fallback(path) |
| | return backend.isdir(path) |
| |
|
| | @patch_func(os.path, 'isfile') |
| | def isfile(path): |
| | backend = get_file_backend(path) |
| | if isinstance(backend, LocalBackend): |
| | return isfile._fallback(path) |
| | return backend.isfile(path) |
| |
|
| | @patch_func(os.path, 'exists') |
| | def exists(path): |
| | backend = get_file_backend(path) |
| | if isinstance(backend, LocalBackend): |
| | return exists._fallback(path) |
| | return backend.exists(path) |
| |
|
| | @patch_func(os, 'listdir') |
| | def listdir(path): |
| | backend = get_file_backend(path) |
| | if isinstance(backend, LocalBackend): |
| | return listdir._fallback(path) |
| | return backend.list_dir_or_file(path) |
| |
|
| | import filecmp |
| |
|
| | @patch_func(filecmp, 'cmp') |
| | def cmp(f1, f2, *args, **kwargs): |
| | with fileio.get_local_path(f1) as f1, fileio.get_local_path(f2) as f2: |
| | return cmp._fallback(f1, f2, *args, **kwargs) |
| |
|
| | import shutil |
| |
|
| | @patch_func(shutil, 'copy') |
| | def copy(src, dst, **kwargs): |
| | backend = get_file_backend(src) |
| | if isinstance(backend, LocalBackend): |
| | return copy._fallback(src, dst, **kwargs) |
| | return backend.copyfile_to_local(str(src), str(dst)) |
| |
|
| | import torch |
| |
|
| | @patch_func(torch, 'load') |
| | def load(f, *args, **kwargs): |
| | if isinstance(f, str): |
| | f = io.BytesIO(fileio.get(f)) |
| | return load._fallback(f, *args, **kwargs) |
| |
|
| | try: |
| | setattr(patch_fileio, '_patched', True) |
| | yield |
| | finally: |
| | for patched_fn in patch_func._backup: |
| | (module, fn_name_to_wrap, fn_to_wrap) = patched_fn |
| | setattr(module, fn_name_to_wrap, fn_to_wrap) |
| | if global_vars is not None and 'open' in global_vars: |
| | global_vars['open'] = bak_open |
| | setattr(patch_fileio, '_patched', False) |
| |
|
| |
|
| | def patch_hf_auto_model(cache_dir=None): |
| | if hasattr('patch_hf_auto_model', '_patched'): |
| | return |
| |
|
| | from transformers.modeling_utils import PreTrainedModel |
| | from transformers.models.auto.auto_factory import _BaseAutoModelClass |
| |
|
| | ori_model_pt = PreTrainedModel.from_pretrained |
| |
|
| | @classmethod |
| | def model_pt(cls, pretrained_model_name_or_path, *args, **kwargs): |
| | kwargs['cache_dir'] = cache_dir |
| | if not isinstance(get_file_backend(pretrained_model_name_or_path), |
| | LocalBackend): |
| | kwargs['local_files_only'] = True |
| | if cache_dir is not None and not isinstance( |
| | get_file_backend(cache_dir), LocalBackend): |
| | kwargs['local_files_only'] = True |
| |
|
| | with patch_fileio(): |
| | res = ori_model_pt.__func__(cls, pretrained_model_name_or_path, |
| | *args, **kwargs) |
| | return res |
| |
|
| | PreTrainedModel.from_pretrained = model_pt |
| |
|
| | |
| | |
| | for auto_class in [ |
| | _BaseAutoModelClass, *_BaseAutoModelClass.__subclasses__() |
| | ]: |
| | ori_auto_pt = auto_class.from_pretrained |
| |
|
| | @classmethod |
| | def auto_pt(cls, pretrained_model_name_or_path, *args, **kwargs): |
| | kwargs['cache_dir'] = cache_dir |
| | if not isinstance(get_file_backend(pretrained_model_name_or_path), |
| | LocalBackend): |
| | kwargs['local_files_only'] = True |
| | if cache_dir is not None and not isinstance( |
| | get_file_backend(cache_dir), LocalBackend): |
| | kwargs['local_files_only'] = True |
| |
|
| | with patch_fileio(): |
| | res = ori_auto_pt.__func__(cls, pretrained_model_name_or_path, |
| | *args, **kwargs) |
| | return res |
| |
|
| | auto_class.from_pretrained = auto_pt |
| |
|
| | patch_hf_auto_model._patched = True |
| |
|