File size: 1,461 Bytes
795e71e
 
 
6159bde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
795e71e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""Modified from https://github.com/khanrc/honeybee
"""

import os

from peft.utils import ModulesToSaveWrapper
from transformers import logging


def get_cache_dir():
    """Return the Hugging Face home directory used to look for cached models.

    Reads the ``HF_HOME`` environment variable and falls back to
    ``~/.cache/huggingface`` when it is unset. The value is passed through
    ``os.path.expanduser`` so that a leading ``~`` is resolved to the user's
    home directory. Without that, ``os.path.exists`` checks on paths built
    from the default would always fail, because the filesystem does not
    expand ``~``.

    Returns:
        str: the cache directory path, with ``~`` expanded.
    """
    default_hf_home = "~/.cache/huggingface"
    cache_dir = os.environ.get("HF_HOME", default_hf_home)
    # "~" is only expanded by shells, not by os.path.exists()/open(); resolve it here.
    return os.path.expanduser(cache_dir)


def check_local_file(model_name_or_path):
    """Decide whether a model should be loaded from the local cache directory.

    Builds ``<cache_dir>/models--<name>``, where ``<cache_dir>`` is the
    directory returned by ``get_cache_dir()`` and ``<name>`` is
    ``model_name_or_path`` with every ``/`` replaced by ``--``. It then checks
    whether that path exists.

    NOTE(review): ``huggingface_hub`` itself stores repositories under
    ``$HF_HOME/hub/models--...``; confirm that looking directly under
    ``$HF_HOME`` is the intended layout here.

    Args:
        model_name_or_path: model identifier such as ``"org/model"``.

    Returns:
        tuple[bool, str]: ``(local_files_only, file_name)``.
        ``local_files_only`` is True iff the cached directory exists.
        ``file_name`` is that directory when it exists, otherwise
        ``model_name_or_path`` unchanged.
    """
    # Expand "~" defensively: os.path.exists() does not resolve it, so an
    # unexpanded cache dir would never match. expanduser() is idempotent.
    cache_dir = os.path.expanduser(get_cache_dir())
    local_dir = os.path.join(
        cache_dir, f"models--{model_name_or_path.replace('/', '--')}"
    )
    if os.path.exists(local_dir):
        return True, local_dir
    return False, model_name_or_path


def unwrap_peft(layer):
    """Return the module underneath a PEFT ``ModulesToSaveWrapper``.

    Intended for inspection only, e.g. checking a layer's dtype or fetching
    model configs. If ``layer`` is a ``ModulesToSaveWrapper``, its
    ``original_module`` attribute is returned. Any other object is returned
    unchanged.
    """
    if not isinstance(layer, ModulesToSaveWrapper):
        return layer
    return layer.original_module


class transformers_log_level:
    """Context manager that temporarily sets the ``transformers`` log verbosity.

    On entry, the current verbosity is recorded and replaced with
    ``log_level``. On exit, the recorded verbosity is restored, including when
    the ``with`` body raises. Exceptions are never suppressed.

    Based on
    https://github.com/huggingface/transformers/issues/5421#issuecomment-1317784733
    """

    orig_log_level: int  # verbosity to restore when the ``with`` block exits
    log_level: int  # verbosity applied while inside the ``with`` block

    def __init__(self, log_level: int):
        """Store the verbosity to apply.

        Args:
            log_level: a level accepted by transformers' ``logging.set_verbosity``.
        """
        self.log_level = log_level
        # Kept so the attribute exists right after construction. It is refreshed
        # in __enter__ so that a context created ahead of time (or reused)
        # restores the level that was current when the block was entered,
        # not a stale one.
        self.orig_log_level = logging.get_verbosity()

    def __enter__(self):
        self.orig_log_level = logging.get_verbosity()
        logging.set_verbosity(self.log_level)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        logging.set_verbosity(self.orig_log_level)
        # A falsy return lets any exception from the ``with`` body propagate.
        return False


def get_rank():
    """Return this process's rank, parsed as an int from the ``RANK`` environment variable.

    Returns 0 when ``RANK`` is not set.
    """
    if "RANK" not in os.environ:
        return 0
    return int(os.environ["RANK"])