| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import torch |
| import time |
| import os |
|
|
|
|
| def _repo_root(): |
| return os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) |
|
|
|
|
def _load_env_file():
    """Populate os.environ from a repo-root ``.env`` file, if one exists.

    Lines are ``KEY=VALUE`` pairs; blanks, comments, and lines without ``=``
    are skipped. Surrounding quotes on values are stripped. Variables already
    present in the environment are never overwritten.
    """
    env_path = os.path.join(_repo_root(), '.env')
    if not os.path.isfile(env_path):
        return
    with open(env_path, encoding='utf-8') as fh:
        for raw in fh:
            entry = raw.strip()
            if not entry or entry.startswith('#') or '=' not in entry:
                continue
            name, _, value = entry.partition('=')
            name = name.strip()
            value = value.strip().strip('"').strip("'")
            if name and name not in os.environ:
                os.environ[name] = value
|
|
|
|
def _ensure_hf_auth():
    """Best-effort Hugging Face Hub login using a token from the environment.

    Loads ``.env`` first, then tries ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN``.
    Missing token or a failed login is non-fatal (anonymous access still works,
    just with lower rate limits).
    """
    _load_env_file()
    token = os.environ.get('HF_TOKEN') or os.environ.get('HUGGING_FACE_HUB_TOKEN')
    if not token:
        return
    try:
        from huggingface_hub import login
        login(token=token, add_to_git_credential=False)
    except Exception as e:
        # Deliberately swallow: auth is an optimization, not a requirement.
        print(f'Warning: Hugging Face login failed ({e})')
|
|
|
|
| def _is_hub_rate_limit(err): |
| s = str(err).lower() |
| return '429' in str(err) or 'too many requests' in s |
|
|
|
|
| def _with_hub_retries(load_fn, attempts=6): |
| delay = 15.0 |
| last = None |
| for i in range(attempts): |
| try: |
| return load_fn() |
| except Exception as e: |
| last = e |
| if _is_hub_rate_limit(e) and i < attempts - 1: |
| print( |
| f'Hugging Face rate limit (429); waiting {delay:.0f}s before retry ' |
| f'({i + 1}/{attempts}). Use HF_TOKEN in .env for higher limits.' |
| ) |
| time.sleep(delay) |
| delay = min(delay * 1.5, 120.0) |
| continue |
| raise |
| raise last |
|
|
|
|
# Authenticate with the Hugging Face Hub at import time so that all later
# from_pretrained calls in this module benefit from authenticated rate limits.
_ensure_hf_auth()
|
|
|
|
def model_max_length(model) -> int:
    """Max positions for truncation (e.g. GPT-2 n_positions=1024). Avoids CUDA gather errors on long texts."""
    config = model.config
    limit = getattr(config, "n_positions", None)
    if limit is not None:
        return int(limit)
    limit = getattr(config, "max_position_embeddings", None)
    # Some configs use a sentinel-huge value here; treat those as "unknown".
    if limit is not None and int(limit) < 1_000_000:
        return int(limit)
    return 1024
|
|
|
|
def from_pretrained(cls, model_name, kwargs, cache_dir):
    """Load *cls* (model/tokenizer class) for *model_name*, preferring a local copy.

    If a ``local.<name>`` directory exists under *cache_dir* it is loaded
    directly; otherwise the checkpoint is fetched from the Hub, with retry
    handling for rate-limit (429) errors.
    """
    local_dirname = 'local.' + model_name.replace("/", "_")
    local_path = os.path.join(cache_dir, local_dirname)
    if os.path.exists(local_path):
        return cls.from_pretrained(local_path, **kwargs)
    return _with_hub_retries(
        lambda: cls.from_pretrained(model_name, **kwargs, cache_dir=cache_dir)
    )
|
|
| |
# Short model aliases (CLI-friendly) -> full Hugging Face Hub repository ids.
model_fullnames = { 'gpt2': 'gpt2',
                    'gpt2-xl': 'gpt2-xl',
                    'opt-2.7b': 'facebook/opt-2.7b',
                    'gpt-neo-2.7B': 'EleutherAI/gpt-neo-2.7B',
                    'gpt-j-6B': 'EleutherAI/gpt-j-6B',
                    'gpt-neox-20b': 'EleutherAI/gpt-neox-20b',
                    'mgpt': 'sberbank-ai/mGPT',
                    'pubmedgpt': 'stanford-crfm/pubmedgpt',
                    'mt5-xl': 'google/mt5-xl',
                    'qwen-7b': 'Qwen/Qwen2.5-7B',
                    'qwen-7b-instruct': 'Qwen/Qwen2.5-7B-Instruct',
                    'mistralai-7b': 'mistralai/Mistral-7B-v0.1',
                    'mistralai-7b-instruct': 'mistralai/Mistral-7B-Instruct-v0.3',
                    'llama3-8b': 'meta-llama/Meta-Llama-3-8B',
                    'llama3-8b-instruct': 'meta-llama/Meta-Llama-3-8B-Instruct',
                    'falcon-7b': 'tiiuae/falcon-7b',
                    'falcon-7b-instruct': 'tiiuae/falcon-7b-instruct',
                    'llama2-13b': 'TheBloke/Llama-2-13B-fp16',
                    'llama2-13b-chat': 'meta-llama/Llama-2-13b-chat-hf',
                    'gemma-9b': 'google/gemma-2-9b',
                    'gemma-9b-instruct': 'google/gemma-2-9b-it',
                    'bloom-7b1': 'bigscience/bloom-7b1',
                    'opt-13b': 'facebook/opt-13b',
                    'pythia-12b': 'EleutherAI/pythia-12b',
                    }
# Aliases whose weights load_model() loads in float16 to halve memory use.
float16_models = ['gpt-neo-2.7B', 'gpt-j-6B', 'gpt-neox-20b', 'falcon-7b', 'falcon-7b-instruct', 'qwen-7b', 'qwen-7b-instruct', 'mistralai-7b', 'mistralai-7b-instruct', 'llama3-8b', 'llama3-8b-instruct', 'gemma-9b', 'gemma-9b-instruct', 'llama2-13b', 'bloom-7b1', 'opt-13b', 'pythia-12b', 'llama2-13b-chat']
|
|
def get_model_fullname(model_name):
    """Map a short model alias to its full Hub repo id; unknown names pass through."""
    # dict.get with a default replaces the hand-rolled `x[k] if k in x else k`
    # (one lookup instead of two, and more idiomatic).
    return model_fullnames.get(model_name, model_name)
|
|
def load_model(model_name, device, cache_dir, torch_dtype=None):
    """Load a causal-LM checkpoint and move it to *device*.

    Aliases listed in ``float16_models`` default to fp16 weights; passing an
    explicit *torch_dtype* overrides that default.
    """
    model_fullname = get_model_fullname(model_name)
    print(f'Loading model {model_fullname}...')
    model_kwargs = {}
    if model_name in float16_models:
        model_kwargs['torch_dtype'] = torch.float16
    if 'gpt-j' in model_name:
        # The fp16 GPT-J weights live on a dedicated 'float16' branch.
        model_kwargs['revision'] = 'float16'
    if torch_dtype is not None:
        # Caller-specified dtype wins over the float16_models default.
        model_kwargs['torch_dtype'] = torch_dtype
    model = from_pretrained(AutoModelForCausalLM, model_fullname, model_kwargs, cache_dir)
    # Fix: report the actual target device — the old message said "GPU" even
    # when callers (e.g. the __main__ smoke test) pass device='cpu'.
    print(f'Moving model to {device}...', end='', flush=True)
    start = time.time()
    model.to(device)
    print(f'DONE ({time.time() - start:.2f}s)')
    return model
|
|
def load_tokenizer(model_name, cache_dir):
    """Load the tokenizer for *model_name* and normalise its padding settings.

    Ensures a pad token is always set (falls back to eos), pads on the right,
    and forces pad_token_id=0 for the 13B checkpoints.
    """
    model_fullname = get_model_fullname(model_name)
    optional_tok_kwargs = {}
    if "facebook/opt-" in model_fullname:
        print("Using non-fast tokenizer for OPT")
        # Bug fix: AutoTokenizer.from_pretrained takes `use_fast`, not `fast`.
        # The old `fast=False` kwarg was silently forwarded and ignored, so the
        # intended slow tokenizer was never actually selected.
        optional_tok_kwargs['use_fast'] = False
    optional_tok_kwargs['padding_side'] = 'right'
    base_tokenizer = from_pretrained(AutoTokenizer, model_fullname, optional_tok_kwargs, cache_dir=cache_dir)
    if base_tokenizer.pad_token_id is None:
        base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
    if '13b' in model_fullname:
        # 13B checkpoints: pad with token id 0 regardless of eos.
        base_tokenizer.pad_token_id = 0
    return base_tokenizer
|
|
|
|
if __name__ == '__main__':
    import argparse

    # Smoke test: fetch the tokenizer and model for one checkpoint on CPU.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default="mistralai-7b-instruct")
    parser.add_argument('--cache_dir', type=str, default="../cache")
    opts = parser.parse_args()

    load_tokenizer(opts.model_name, opts.cache_dir)
    load_model(opts.model_name, 'cpu', opts.cache_dir)
|
|