from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import os
def _repo_root():
return os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
def _load_env_file():
    """Populate os.environ from a KEY=VALUE `.env` file at the repo root.

    Blank lines, comment lines starting with '#', and lines without '='
    are skipped. Values are whitespace-stripped and surrounding double
    then single quotes are removed. Variables already present in the
    environment are never overwritten.
    """
    env_path = os.path.join(_repo_root(), '.env')
    if not os.path.isfile(env_path):
        return
    with open(env_path, encoding='utf-8') as handle:
        for raw in handle:
            entry = raw.strip()
            if not entry or entry.startswith('#') or '=' not in entry:
                continue
            name, _, value = entry.partition('=')
            name = name.strip()
            value = value.strip().strip('"').strip("'")
            if name and name not in os.environ:
                os.environ[name] = value
def _ensure_hf_auth():
    """Log in to the Hugging Face Hub if a token is configured.

    Loads the repo `.env` first, then checks HF_TOKEN and
    HUGGING_FACE_HUB_TOKEN in that order. A failed login is reported
    as a warning instead of raising, so offline runs still work.
    """
    _load_env_file()
    for var in ('HF_TOKEN', 'HUGGING_FACE_HUB_TOKEN'):
        token = os.environ.get(var)
        if token:
            break
    else:
        return
    try:
        from huggingface_hub import login
        login(token=token, add_to_git_credential=False)
    except Exception as e:
        print(f'Warning: Hugging Face login failed ({e})')
def _is_hub_rate_limit(err):
s = str(err).lower()
return '429' in str(err) or 'too many requests' in s
def _with_hub_retries(load_fn, attempts=6):
    """Call *load_fn*, retrying with exponential backoff on Hub rate limits.

    Args:
        load_fn: zero-argument callable performing the Hub download.
        attempts: maximum number of calls before giving up (must be >= 1).

    Returns:
        Whatever load_fn returns on success.

    Raises:
        Whatever load_fn raised, when the error is not a rate limit or the
        retry budget is exhausted; ValueError when attempts < 1.
    """
    delay = 15.0
    for attempt in range(attempts):
        try:
            return load_fn()
        except Exception as e:
            # Only back off on 429s while retries remain; everything else
            # (and the final 429) propagates immediately via bare `raise`.
            if not (_is_hub_rate_limit(e) and attempt < attempts - 1):
                raise
            print(
                f'Hugging Face rate limit (429); waiting {delay:.0f}s before retry '
                f'({attempt + 1}/{attempts}). Use HF_TOKEN in .env for higher limits.'
            )
            time.sleep(delay)
            delay = min(delay * 1.5, 120.0)  # cap backoff at 2 minutes
    # Original ended with `raise last`, which was unreachable for
    # attempts >= 1 and raised a confusing `raise None` TypeError for
    # attempts <= 0; fail loudly with the real precondition instead.
    raise ValueError('attempts must be >= 1')
# Authenticate with the Hugging Face Hub at import time so that gated
# checkpoints (e.g. meta-llama repos listed below) can be downloaded.
_ensure_hf_auth()
def model_max_length(model) -> int:
    """Return the model's maximum sequence length for safe truncation.

    Prefers config.n_positions (GPT-2 style, e.g. 1024), then
    config.max_position_embeddings when it is below 1e6 (some configs
    store a huge sentinel there), and finally falls back to 1024.
    Truncating to this length avoids CUDA gather errors on long texts.
    """
    config = model.config
    if (positions := getattr(config, "n_positions", None)) is not None:
        return int(positions)
    embeddings = getattr(config, "max_position_embeddings", None)
    if embeddings is not None and int(embeddings) < 1_000_000:
        return int(embeddings)
    return 1024
def from_pretrained(cls, model_name, kwargs, cache_dir):
    """Load a model or tokenizer class, preferring a pre-downloaded local copy.

    Looks for a 'local.<name>' snapshot inside *cache_dir* first (slashes
    in the repo id replaced by underscores); otherwise downloads from the
    Hub with rate-limit retries.
    """
    sanitized = 'local.' + model_name.replace("/", "_")
    local_path = os.path.join(cache_dir, sanitized)
    if os.path.exists(local_path):
        return cls.from_pretrained(local_path, **kwargs)
    return _with_hub_retries(
        lambda: cls.from_pretrained(model_name, **kwargs, cache_dir=cache_dir)
    )
# predefined models
# Short aliases -> Hugging Face Hub repo ids. Names not listed here pass
# through get_model_fullname unchanged, so full repo ids also work.
model_fullnames = { 'gpt2': 'gpt2',
                    'gpt2-xl': 'gpt2-xl',
                    'opt-2.7b': 'facebook/opt-2.7b',
                    'gpt-neo-2.7B': 'EleutherAI/gpt-neo-2.7B',
                    'gpt-j-6B': 'EleutherAI/gpt-j-6B',
                    'gpt-neox-20b': 'EleutherAI/gpt-neox-20b',
                    'mgpt': 'sberbank-ai/mGPT',
                    'pubmedgpt': 'stanford-crfm/pubmedgpt',
                    'mt5-xl': 'google/mt5-xl',
                    'qwen-7b': 'Qwen/Qwen2.5-7B',
                    'qwen-7b-instruct': 'Qwen/Qwen2.5-7B-Instruct',
                    'mistralai-7b': 'mistralai/Mistral-7B-v0.1',
                    'mistralai-7b-instruct': 'mistralai/Mistral-7B-Instruct-v0.3',
                    'llama3-8b': 'meta-llama/Meta-Llama-3-8B',
                    'llama3-8b-instruct': 'meta-llama/Meta-Llama-3-8B-Instruct',
                    'falcon-7b': 'tiiuae/falcon-7b',
                    'falcon-7b-instruct': 'tiiuae/falcon-7b-instruct',
                    'llama2-13b': 'TheBloke/Llama-2-13B-fp16',
                    'llama2-13b-chat': 'meta-llama/Llama-2-13b-chat-hf',
                    'gemma-9b': 'google/gemma-2-9b',
                    'gemma-9b-instruct': 'google/gemma-2-9b-it',
                    'bloom-7b1': 'bigscience/bloom-7b1',
                    'opt-13b': 'facebook/opt-13b',
                    'pythia-12b': 'EleutherAI/pythia-12b',
                    }
# Aliases that load_model loads in float16 by default (large checkpoints
# that would not fit common GPU memory in float32).
float16_models = ['gpt-neo-2.7B', 'gpt-j-6B', 'gpt-neox-20b', 'falcon-7b', 'falcon-7b-instruct', 'qwen-7b', 'qwen-7b-instruct', 'mistralai-7b', 'mistralai-7b-instruct', 'llama3-8b', 'llama3-8b-instruct', 'gemma-9b', 'gemma-9b-instruct', 'llama2-13b', 'bloom-7b1', 'opt-13b', 'pythia-12b', 'llama2-13b-chat']
def get_model_fullname(model_name):
    """Map a short model alias to its Hub repo id; unknown names pass through.

    Uses dict.get to avoid the double lookup of the original
    membership-test-then-index pattern; behavior is unchanged.
    """
    return model_fullnames.get(model_name, model_name)
def load_model(model_name, device, cache_dir, torch_dtype=None):
    """Load a causal LM by short alias or repo id and move it to *device*.

    Large predefined models (float16_models) default to float16; an
    explicit *torch_dtype* argument always overrides that. gpt-j models
    additionally use the dedicated 'float16' Hub revision.
    """
    model_fullname = get_model_fullname(model_name)
    print(f'Loading model {model_fullname}...')
    model_kwargs = {}
    if model_name in float16_models:
        model_kwargs['torch_dtype'] = torch.float16
    if 'gpt-j' in model_name:
        model_kwargs['revision'] = 'float16'
    if torch_dtype is not None:
        # Explicit dtype wins over the float16 default set above.
        model_kwargs['torch_dtype'] = torch_dtype
    model = from_pretrained(AutoModelForCausalLM, model_fullname, model_kwargs, cache_dir)
    print('Moving model to GPU...', end='', flush=True)
    start = time.time()
    model.to(device)
    print(f'DONE ({time.time() - start:.2f}s)')
    return model
def load_tokenizer(model_name, cache_dir):
    """Load the tokenizer for *model_name* with normalized padding settings.

    Falls back to EOS as pad token (most causal LMs define none), forces
    the slow tokenizer and right-side padding for OPT, and hard-codes
    pad_token_id=0 for 13B checkpoints (the convention those repos use).
    """
    model_fullname = get_model_fullname(model_name)
    optional_tok_kwargs = {}
    if "facebook/opt-" in model_fullname:
        print("Using non-fast tokenizer for OPT")
        # BUG FIX: the AutoTokenizer keyword is `use_fast`, not `fast`.
        # The old key was silently ignored, so a fast tokenizer loaded
        # anyway despite the message printed above.
        optional_tok_kwargs['use_fast'] = False
    optional_tok_kwargs['padding_side'] = 'right'
    base_tokenizer = from_pretrained(AutoTokenizer, model_fullname, optional_tok_kwargs, cache_dir=cache_dir)
    if base_tokenizer.pad_token_id is None:
        base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
    if '13b' in model_fullname:
        base_tokenizer.pad_token_id = 0
    return base_tokenizer
# Smoke-test / pre-download entry point: fetches the tokenizer and model
# for --model_name into --cache_dir so later runs can load them locally.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default="mistralai-7b-instruct")
    parser.add_argument('--cache_dir', type=str, default="../cache")
    args = parser.parse_args()
    load_tokenizer(args.model_name, args.cache_dir)
    # 'cpu' device: this entry point only downloads weights, no inference.
    load_model(args.model_name, 'cpu', args.cache_dir)