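"""Utilities for loading Hugging Face causal LMs and their tokenizers.

Maps short model names (e.g. 'gpt2-xl', 'llama3-8b') to Hub repository ids,
authenticates with a token from the environment or a repo-root .env file,
retries downloads on HTTP 429 rate limits, and stores weights in a local
cache directory.
"""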
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import os


def _repo_root():
    return os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))


def _load_env_file():
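    """Read KEY=VALUE pairs from a repo-root .env file into os.environ, without overriding variables that are already set."""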
    path = os.path.join(_repo_root(), '.env')
    if not os.path.isfile(path):
        return
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#') or '=' not in line:
                continue
            key, _, val = line.partition('=')
            key, val = key.strip(), val.strip().strip('"').strip("'")
            if key and key not in os.environ:
                os.environ[key] = val


def _ensure_hf_auth():
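    """Log in to the Hugging Face Hub when HF_TOKEN or HUGGING_FACE_HUB_TOKEN is set (directly or via .env)."""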
    _load_env_file()
    token = os.environ.get('HF_TOKEN') or os.environ.get('HUGGING_FACE_HUB_TOKEN')
    if not token:
        return
    try:
        from huggingface_hub import login
        login(token=token, add_to_git_credential=False)
    except Exception as e:
        print(f'Warning: Hugging Face login failed ({e})')


def _is_hub_rate_limit(err):
    s = str(err).lower()
    return '429' in s or 'too many requests' in s


def _with_hub_retries(load_fn, attempts=6):
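    """Call `load_fn`, retrying with exponential backoff whenever the Hub reports a 429 rate limit."""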
    delay = 15.0
    last = None
    for i in range(attempts):
        try:
            return load_fn()
        except Exception as e:
            last = e
            if _is_hub_rate_limit(e) and i < attempts - 1:
                print(
                    f'Hugging Face rate limit (429); waiting {delay:.0f}s before retry '
                    f'({i + 1}/{attempts}). Use HF_TOKEN in .env for higher limits.'
                )
                time.sleep(delay)
                delay = min(delay * 1.5, 120.0)
                continue
            raise
    raise last


_ensure_hf_auth()


def model_max_length(model) -> int:
    """Max positions for truncation (e.g. GPT-2 n_positions=1024). Avoids CUDA gather errors on long texts."""
    cfg = model.config
    n = getattr(cfg, "n_positions", None)
    if n is not None:
        return int(n)
    n = getattr(cfg, "max_position_embeddings", None)
    if n is not None and int(n) < 1_000_000:
        return int(n)
    return 1024


def from_pretrained(cls, model_name, kwargs, cache_dir):
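    """Load from a 'local.<name>' directory under `cache_dir` if present; otherwise download from the Hub with rate-limit retries."""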
    # use local model if it exists
    local_path = os.path.join(cache_dir, 'local.' + model_name.replace("/", "_"))
    if os.path.exists(local_path):
        return cls.from_pretrained(local_path, **kwargs)

    def _load():
        return cls.from_pretrained(model_name, **kwargs, cache_dir=cache_dir)

    return _with_hub_retries(_load)

# predefined models
model_fullnames = {
    'gpt2': 'gpt2',
    'gpt2-xl': 'gpt2-xl',
    'opt-2.7b': 'facebook/opt-2.7b',
    'gpt-neo-2.7B': 'EleutherAI/gpt-neo-2.7B',
    'gpt-j-6B': 'EleutherAI/gpt-j-6B',
    'gpt-neox-20b': 'EleutherAI/gpt-neox-20b',
    'mgpt': 'sberbank-ai/mGPT',
    'pubmedgpt': 'stanford-crfm/pubmedgpt',
    'mt5-xl': 'google/mt5-xl',
    'qwen-7b': 'Qwen/Qwen2.5-7B',
    'qwen-7b-instruct': 'Qwen/Qwen2.5-7B-Instruct',
    'mistralai-7b': 'mistralai/Mistral-7B-v0.1',
    'mistralai-7b-instruct': 'mistralai/Mistral-7B-Instruct-v0.3',
    'llama3-8b': 'meta-llama/Meta-Llama-3-8B',
    'llama3-8b-instruct': 'meta-llama/Meta-Llama-3-8B-Instruct',
    'falcon-7b': 'tiiuae/falcon-7b',
    'falcon-7b-instruct': 'tiiuae/falcon-7b-instruct',
    'llama2-13b': 'TheBloke/Llama-2-13B-fp16',
    'llama2-13b-chat': 'meta-llama/Llama-2-13b-chat-hf',
    'gemma-9b': 'google/gemma-2-9b',
    'gemma-9b-instruct': 'google/gemma-2-9b-it',
    'bloom-7b1': 'bigscience/bloom-7b1',
    'opt-13b': 'facebook/opt-13b',
    'pythia-12b': 'EleutherAI/pythia-12b',
}
# models loaded in float16 by default
float16_models = [
    'gpt-neo-2.7B', 'gpt-j-6B', 'gpt-neox-20b',
    'falcon-7b', 'falcon-7b-instruct',
    'qwen-7b', 'qwen-7b-instruct',
    'mistralai-7b', 'mistralai-7b-instruct',
    'llama3-8b', 'llama3-8b-instruct',
    'gemma-9b', 'gemma-9b-instruct',
    'llama2-13b', 'llama2-13b-chat',
    'bloom-7b1', 'opt-13b', 'pythia-12b',
]

def get_model_fullname(model_name):
    return model_fullnames.get(model_name, model_name)

def load_model(model_name, device, cache_dir, torch_dtype=None):
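    """Load a causal LM by short name, applying float16/revision options where configured, and move it to `device`."""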
    model_fullname = get_model_fullname(model_name)
    print(f'Loading model {model_fullname}...')
    model_kwargs = {}
    if model_name in float16_models:
        model_kwargs.update(dict(torch_dtype=torch.float16))
    if 'gpt-j' in model_name:
        model_kwargs.update(dict(revision='float16'))
    if torch_dtype is not None:
        model_kwargs.update(dict(torch_dtype=torch_dtype))
    model = from_pretrained(AutoModelForCausalLM, model_fullname, model_kwargs, cache_dir)
    print(f'Moving model to {device}...', end='', flush=True)
    start = time.time()
    model.to(device)
    print(f'DONE ({time.time() - start:.2f}s)')
    return model

def load_tokenizer(model_name, cache_dir):
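    """Load the tokenizer for `model_name`, enforcing right-side padding and a usable pad token."""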
    model_fullname = get_model_fullname(model_name)
    optional_tok_kwargs = {}
    if "facebook/opt-" in model_fullname:
        print("Using non-fast tokenizer for OPT")
        optional_tok_kwargs['use_fast'] = False  # 'use_fast' is the kwarg AutoTokenizer recognizes; 'fast' was silently ignored
    optional_tok_kwargs['padding_side'] = 'right'
    base_tokenizer = from_pretrained(AutoTokenizer, model_fullname, optional_tok_kwargs, cache_dir=cache_dir)
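    # Reuse EOS as the pad token when none is defined; the 13B checkpoints here use token id 0 instead.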
    if base_tokenizer.pad_token_id is None:
        base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
        if '13b' in model_fullname:
            base_tokenizer.pad_token_id = 0
    return base_tokenizer


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default="mistralai-7b-instruct")
    parser.add_argument('--cache_dir', type=str, default="../cache")
    args = parser.parse_args()

    load_tokenizer(args.model_name, args.cache_dir)
    load_model(args.model_name, 'cpu', args.cache_dir)