| """Modified from https://github.com/khanrc/honeybee |
| """ |
|
|
| import os |
|
|
| from peft.utils import ModulesToSaveWrapper |
| from transformers import logging |
|
|
|
|
def get_cache_dir():
    """Return the Hugging Face cache directory as an absolute path.

    Uses the ``HF_HOME`` environment variable when set, otherwise the
    default ``~/.cache/huggingface``. The result is passed through
    ``os.path.expanduser`` because a literal leading ``~`` is never
    resolved by ``os.path`` functions — without the expansion, the
    ``os.path.exists`` check in ``check_local_file`` always fails for
    the default location.

    Returns:
        str: the cache directory path, with ``~`` expanded.
    """
    default_hf_home = "~/.cache/huggingface"
    cache_dir = os.environ.get("HF_HOME", default_hf_home)
    # Expand "~" from either source so the path is usable with
    # os.path.join / os.path.exists downstream.
    return os.path.expanduser(cache_dir)
|
|
|
|
def check_local_file(model_name_or_path):
    """Check whether *model_name_or_path* already exists in the HF cache.

    Args:
        model_name_or_path: a Hub model id (e.g. ``"org/model"``) or path.

    Returns:
        tuple: ``(local_files_only, file_name)`` — when a cached
        ``models--org--model`` directory exists under the cache dir,
        ``file_name`` is that directory and ``local_files_only`` is
        ``True``; otherwise the original identifier is returned with
        ``False``.
    """
    # Hub ids are stored on disk with "/" replaced by "--".
    cached_name = f"models--{model_name_or_path.replace('/', '--')}"
    candidate = os.path.join(get_cache_dir(), cached_name)
    if os.path.exists(candidate):
        return True, candidate
    return False, model_name_or_path
|
|
|
|
def unwrap_peft(layer):
    """Strip a PEFT ``ModulesToSaveWrapper`` from *layer*, if present.

    Designed for checking the dtype of a model or fetching model
    configs; non-wrapped layers are returned unchanged.
    """
    is_wrapped = isinstance(layer, ModulesToSaveWrapper)
    return layer.original_module if is_wrapped else layer
|
|
|
|
class transformers_log_level:
    """Context manager that temporarily sets the ``transformers`` log level.

    Based on:
    https://github.com/huggingface/transformers/issues/5421#issuecomment-1317784733

    The previous verbosity is captured on ``__enter__`` rather than
    ``__init__`` so a single instance can be re-used (or nested) and
    still restore the verbosity that was actually in effect when the
    ``with`` block started. The original level is restored on
    ``__exit__`` even if the body raises.
    """

    orig_log_level: int  # verbosity to restore on exit (captured at entry)
    log_level: int       # verbosity applied inside the context

    def __init__(self, log_level: int):
        self.log_level = log_level

    def __enter__(self):
        # Capture at entry time, not construction time, so stale levels
        # are never restored when the instance is re-used.
        self.orig_log_level = logging.get_verbosity()
        logging.set_verbosity(self.log_level)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        logging.set_verbosity(self.orig_log_level)
|
|
|
|
def get_rank():
    """Return the distributed process rank from the ``RANK`` environment
    variable, defaulting to 0 when it is unset."""
    rank = os.environ.get("RANK")
    return 0 if rank is None else int(rank)
|
|