"""Modified from https://github.com/khanrc/honeybee
"""
import os
from peft.utils import ModulesToSaveWrapper
from transformers import logging
def get_cache_dir():
    """Return the Hugging Face cache directory as an absolute, usable path.

    Honors the ``HF_HOME`` environment variable when set; otherwise falls
    back to the default ``~/.cache/huggingface``.

    Returns:
        str: The cache directory with ``~`` expanded. Expansion is the fix:
        the original returned the literal ``~/...`` string, so the
        ``os.path.exists`` check in ``check_local_file`` could never
        succeed when ``HF_HOME`` was unset.
    """
    DEFAULT_HF_HOME = "~/.cache/huggingface"
    cache_dir = os.environ.get("HF_HOME", DEFAULT_HF_HOME)
    # expanduser is a no-op for paths that contain no "~" (e.g. an
    # explicitly set HF_HOME), so existing behavior is preserved there.
    return os.path.expanduser(cache_dir)
def check_local_file(model_name_or_path):
    """Check whether a Hugging Face model snapshot is already cached locally.

    Args:
        model_name_or_path: Hub model id (e.g. ``org/name``) or a path.

    Returns:
        tuple[bool, str]: ``(local_files_only, file_name)`` — if the
        ``models--org--name`` directory exists in the cache, returns
        ``(True, <cache path>)``; otherwise ``(False, model_name_or_path)``
        so callers fall back to downloading from the Hub.
    """
    cached_path = os.path.join(
        get_cache_dir(), f"models--{model_name_or_path.replace('/', '--')}"
    )
    if os.path.exists(cached_path):
        return True, cached_path
    return False, model_name_or_path
def unwrap_peft(layer):
    """Unwrap a PEFT ``ModulesToSaveWrapper``, if present.

    Designed for the purpose of checking the dtype of a model or fetching
    model configs: returns ``layer.original_module`` for wrapped layers,
    and the layer itself for everything else.
    """
    if not isinstance(layer, ModulesToSaveWrapper):
        return layer
    return layer.original_module
class transformers_log_level:
    """Context manager that temporarily overrides the transformers log level.

    Based on:
    https://github.com/huggingface/transformers/issues/5421#issuecomment-1317784733
    """

    orig_log_level: int
    log_level: int

    def __init__(self, log_level: int):
        # Snapshot the current verbosity at construction time so it can
        # be restored when the context exits.
        self.orig_log_level = logging.get_verbosity()
        self.log_level = log_level

    def __enter__(self):
        logging.set_verbosity(self.log_level)

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the previous verbosity; returning None means any
        # exception raised inside the block propagates as before.
        logging.set_verbosity(self.orig_log_level)
def get_rank():
    """Return the distributed process rank from the ``RANK`` env var.

    Defaults to 0 when ``RANK`` is not set (single-process execution).
    """
    rank = os.environ.get("RANK")
    return int(rank) if rank is not None else 0