M4CXR-TNNLS / utils.py
jonggwon-park's picture
debug logger
bc4a81d
"""Modified from https://github.com/khanrc/honeybee
"""
import os
from peft.utils import ModulesToSaveWrapper
from transformers import logging
def get_cache_dir():
    """Return the Hugging Face home directory as an absolute-style path.

    Resolution order:
      1. ``$HF_HOME`` if set;
      2. otherwise ``$XDG_CACHE_HOME/huggingface`` if ``XDG_CACHE_HOME`` is set;
      3. otherwise ``~/.cache/huggingface``.

    A leading ``~`` is expanded to the user's home directory. Without this,
    filesystem checks such as ``os.path.exists`` would look for a literal
    directory named ``~`` relative to the current working directory.

    Returns:
        The resolved cache directory path (``str``).
    """
    default_hf_home = os.path.join(
        os.environ.get("XDG_CACHE_HOME", os.path.join("~", ".cache")),
        "huggingface",
    )
    cache_dir = os.environ.get("HF_HOME", default_hf_home)
    # os.path does not expand "~" implicitly; do it once here for all callers.
    return os.path.expanduser(cache_dir)
def check_local_file(model_name_or_path):
    """Look for a locally cached copy of a hub model.

    The repo id is mapped to a folder name by replacing ``/`` with ``--`` and
    prefixing ``models--`` (e.g. ``"org/name"`` -> ``"models--org--name"``).
    That folder is then looked up directly under ``get_cache_dir()``.

    Args:
        model_name_or_path: Model identifier, typically a hub repo id such as
            ``"org/name"``.

    Returns:
        A ``(local_files_only, file_name)`` tuple:
        ``(True, <path of the cached folder>)`` if the folder exists, otherwise
        ``(False, model_name_or_path)`` with the input returned unchanged.
    """
    # Expand "~" here as well so the existence check works even when the cache
    # dir (default or $HF_HOME) is expressed relative to the home directory.
    cache_dir = os.path.expanduser(get_cache_dir())
    # NOTE(review): huggingface_hub's default model cache is "$HF_HOME/hub";
    # this looks directly under HF_HOME. Confirm that layout is intended.
    cached_path = os.path.join(
        cache_dir, f"models--{model_name_or_path.replace('/', '--')}"
    )
    local_files_only = os.path.exists(cached_path)
    file_name = cached_path if local_files_only else model_name_or_path
    return local_files_only, file_name
def unwrap_peft(layer):
    """Strip a PEFT ``ModulesToSaveWrapper`` from ``layer`` if one is present.

    Meant for inspecting a module (for example its dtype or its config)
    without going through the wrapper.

    Args:
        layer: Any module or object.

    Returns:
        ``layer.original_module`` when ``layer`` is a ``ModulesToSaveWrapper``;
        otherwise ``layer`` itself, unchanged.
    """
    is_wrapped = isinstance(layer, ModulesToSaveWrapper)
    return layer.original_module if is_wrapped else layer
class transformers_log_level:
    """Context manager that temporarily sets the ``transformers`` logging verbosity.

    Adapted from
    https://github.com/huggingface/transformers/issues/5421#issuecomment-1317784733

    Usage::

        with transformers_log_level(logging.ERROR):
            ...  # transformers logs below ERROR are suppressed here

    On exit, the verbosity is restored to the value in effect when the
    ``with`` block was entered. This also happens when the block raises, and
    the exception is not suppressed.
    """

    orig_log_level: int  # verbosity to restore on exit
    log_level: int  # verbosity applied inside the ``with`` block

    def __init__(self, log_level: int):
        self.log_level = log_level
        # Kept so the attribute exists before entering (backward compatible);
        # it is refreshed in __enter__.
        self.orig_log_level = logging.get_verbosity()

    def __enter__(self):
        # Snapshot the current level at entry, not only at construction.
        # The instance may be created earlier (or reused), and the global
        # verbosity may have changed in between.
        self.orig_log_level = logging.get_verbosity()
        logging.set_verbosity(self.log_level)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        logging.set_verbosity(self.orig_log_level)
        # Implicit None (falsy) return: exceptions raised in the block propagate.
def get_rank():
    """Return the process rank read from the ``RANK`` environment variable.

    Returns:
        ``int(os.environ["RANK"])`` when the variable is set, otherwise ``0``.
        Presumably ``RANK`` is populated by the distributed launcher; verify
        this for your launch setup.

    Raises:
        ValueError: if ``RANK`` is set to a value ``int()`` cannot parse.
    """
    raw_rank = os.environ.get("RANK")
    if raw_rank is None:
        return 0
    return int(raw_rank)