"""Modified from https://github.com/khanrc/honeybee"""

import os

from peft.utils import ModulesToSaveWrapper
from transformers import logging


def get_cache_dir():
    """Return the Hugging Face home directory used to look up cached models.

    Reads the ``HF_HOME`` environment variable, falling back to
    ``~/.cache/huggingface``. The result is passed through
    ``os.path.expanduser`` so that a leading ``~`` refers to the user's home
    directory. Without expansion, filesystem checks such as
    ``os.path.exists`` would look for a directory literally named ``~`` and
    never find the cache.
    """
    default_hf_home = "~/.cache/huggingface"
    return os.path.expanduser(os.environ.get("HF_HOME", default_hf_home))


def check_local_file(model_name_or_path):
    """Check whether a model already has a cache entry under the HF home dir.

    The entry is looked up at ``<cache_dir>/models--<org>--<name>``, i.e. the
    repo id with every ``/`` replaced by ``--``.

    NOTE(review): huggingface_hub's default layout places repos under
    ``$HF_HOME/hub``. This lookup assumes models were downloaded with
    ``cache_dir`` pointing directly at ``HF_HOME``. Confirm against callers.

    Args:
        model_name_or_path: Hub repo id (e.g. ``"org/name"``) or a path.

    Returns:
        A ``(local_files_only, file_name)`` tuple. If the cache entry exists,
        ``local_files_only`` is True and ``file_name`` is the cache path.
        Otherwise ``local_files_only`` is False and ``file_name`` is
        ``model_name_or_path`` unchanged.
    """
    cache_path = os.path.join(
        get_cache_dir(), f"models--{model_name_or_path.replace('/', '--')}"
    )
    if os.path.exists(cache_path):
        return True, cache_path
    return False, model_name_or_path


def unwrap_peft(layer):
    """Return the underlying module if ``layer`` is a PEFT ``ModulesToSaveWrapper``.

    This is used when checking a model's dtype or fetching its configs, where
    the wrapper must be looked through. Any other object is returned
    unchanged.
    """
    if isinstance(layer, ModulesToSaveWrapper):
        return layer.original_module
    return layer


class transformers_log_level:
    """Context manager that temporarily sets the ``transformers`` log verbosity.

    Adapted from
    https://github.com/huggingface/transformers/issues/5421#issuecomment-1317784733

    Example::

        with transformers_log_level(logging.ERROR):
            ...  # transformers verbosity is ERROR inside this block

    On exit, the verbosity that was in effect when the ``with`` block was
    entered is restored, including when the block raises.
    """

    orig_log_level: int  # verbosity to restore on exit
    log_level: int  # verbosity applied while inside the context

    def __init__(self, log_level: int):
        self.log_level = log_level
        # Set here so the attribute exists before entering; it is refreshed in
        # __enter__ so the restored level reflects the state at entry time.
        self.orig_log_level = logging.get_verbosity()

    def __enter__(self):
        # Capture at entry (not construction) so a level change between
        # construction and `with`, or reuse of the instance, restores the
        # correct value.
        self.orig_log_level = logging.get_verbosity()
        logging.set_verbosity(self.log_level)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        logging.set_verbosity(self.orig_log_level)
        # Implicitly returns None (falsy), so exceptions from the block propagate.


def get_rank():
    """Return the process rank read from the ``RANK`` env var (0 if unset)."""
    return int(os.environ.get("RANK", 0))