from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import os


def _repo_root():
    return os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))


def _load_env_file():
    # Load KEY=VALUE pairs from the repo-level .env file into os.environ,
    # without overwriting variables that are already set.
    path = os.path.join(_repo_root(), '.env')
    if not os.path.isfile(path):
        return
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#') or '=' not in line:
                continue
            key, _, val = line.partition('=')
            key, val = key.strip(), val.strip().strip('"').strip("'")
            if key and key not in os.environ:
                os.environ[key] = val


def _ensure_hf_auth():
    # Log in to the Hugging Face Hub if a token is available in the environment or .env.
    _load_env_file()
    token = os.environ.get('HF_TOKEN') or os.environ.get('HUGGING_FACE_HUB_TOKEN')
    if not token:
        return
    try:
        from huggingface_hub import login
        login(token=token, add_to_git_credential=False)
    except Exception as e:
        print(f'Warning: Hugging Face login failed ({e})')


def _is_hub_rate_limit(err):
    s = str(err).lower()
    return '429' in s or 'too many requests' in s


def _with_hub_retries(load_fn, attempts=6):
    # Retry Hub downloads on HTTP 429 with exponential backoff (15s, capped at 120s).
    delay = 15.0
    last = None
    for i in range(attempts):
        try:
            return load_fn()
        except Exception as e:
            last = e
            if _is_hub_rate_limit(e) and i < attempts - 1:
                print(
                    f'Hugging Face rate limit (429); waiting {delay:.0f}s before retry '
                    f'({i + 1}/{attempts}). Use HF_TOKEN in .env for higher limits.'
                )
                time.sleep(delay)
                delay = min(delay * 1.5, 120.0)
                continue
            raise
    raise last


_ensure_hf_auth()


def model_max_length(model) -> int:
    """Max positions for truncation (e.g. GPT-2 n_positions=1024).

    Avoids CUDA gather errors on texts longer than the position embeddings.
    """
    cfg = model.config
    n = getattr(cfg, "n_positions", None)
    if n is not None:
        return int(n)
    n = getattr(cfg, "max_position_embeddings", None)
    if n is not None and int(n) < 1_000_000:
        return int(n)
    return 1024


def from_pretrained(cls, model_name, kwargs, cache_dir):
    # use local model if it exists
    local_path = os.path.join(cache_dir, 'local.' + model_name.replace("/", "_"))
    if os.path.exists(local_path):
        return cls.from_pretrained(local_path, **kwargs)

    def _load():
        return cls.from_pretrained(model_name, **kwargs, cache_dir=cache_dir)

    return _with_hub_retries(_load)


# predefined models
model_fullnames = {
    'gpt2': 'gpt2',
    'gpt2-xl': 'gpt2-xl',
    'opt-2.7b': 'facebook/opt-2.7b',
    'gpt-neo-2.7B': 'EleutherAI/gpt-neo-2.7B',
    'gpt-j-6B': 'EleutherAI/gpt-j-6B',
    'gpt-neox-20b': 'EleutherAI/gpt-neox-20b',
    'mgpt': 'sberbank-ai/mGPT',
    'pubmedgpt': 'stanford-crfm/pubmedgpt',
    'mt5-xl': 'google/mt5-xl',
    'qwen-7b': 'Qwen/Qwen2.5-7B',
    'qwen-7b-instruct': 'Qwen/Qwen2.5-7B-Instruct',
    'mistralai-7b': 'mistralai/Mistral-7B-v0.1',
    'mistralai-7b-instruct': 'mistralai/Mistral-7B-Instruct-v0.3',
    'llama3-8b': 'meta-llama/Meta-Llama-3-8B',
    'llama3-8b-instruct': 'meta-llama/Meta-Llama-3-8B-Instruct',
    'falcon-7b': 'tiiuae/falcon-7b',
    'falcon-7b-instruct': 'tiiuae/falcon-7b-instruct',
    'llama2-13b': 'TheBloke/Llama-2-13B-fp16',
    'llama2-13b-chat': 'meta-llama/Llama-2-13b-chat-hf',
    'gemma-9b': 'google/gemma-2-9b',
    'gemma-9b-instruct': 'google/gemma-2-9b-it',
    'bloom-7b1': 'bigscience/bloom-7b1',
    'opt-13b': 'facebook/opt-13b',
    'pythia-12b': 'EleutherAI/pythia-12b',
}

float16_models = ['gpt-neo-2.7B', 'gpt-j-6B', 'gpt-neox-20b', 'falcon-7b', 'falcon-7b-instruct',
                  'qwen-7b', 'qwen-7b-instruct', 'mistralai-7b', 'mistralai-7b-instruct',
                  'llama3-8b', 'llama3-8b-instruct', 'gemma-9b', 'gemma-9b-instruct',
                  'llama2-13b', 'bloom-7b1', 'opt-13b', 'pythia-12b', 'llama2-13b-chat']


def get_model_fullname(model_name):
    return model_fullnames[model_name] if model_name in model_fullnames else model_name


def load_model(model_name, device, cache_dir, torch_dtype=None):
    model_fullname = get_model_fullname(model_name)
    print(f'Loading model {model_fullname}...')
    model_kwargs = {}
    if model_name in float16_models:
        model_kwargs.update(dict(torch_dtype=torch.float16))
    if 'gpt-j' in model_name:
        model_kwargs.update(dict(revision='float16'))
    if torch_dtype is not None:
        model_kwargs.update(dict(torch_dtype=torch_dtype))
    model = from_pretrained(AutoModelForCausalLM, model_fullname, model_kwargs, cache_dir)
    print(f'Moving model to {device}...', end='', flush=True)
    start = time.time()
    model.to(device)
    print(f'DONE ({time.time() - start:.2f}s)')
    return model


def load_tokenizer(model_name, cache_dir):
    model_fullname = get_model_fullname(model_name)
    optional_tok_kwargs = {}
    if "facebook/opt-" in model_fullname:
        print("Using non-fast tokenizer for OPT")
        # AutoTokenizer expects `use_fast`, not `fast`, to disable the fast tokenizer.
        optional_tok_kwargs['use_fast'] = False
    optional_tok_kwargs['padding_side'] = 'right'
    base_tokenizer = from_pretrained(AutoTokenizer, model_fullname, optional_tok_kwargs, cache_dir=cache_dir)
    if base_tokenizer.pad_token_id is None:
        base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
        if '13b' in model_fullname:
            base_tokenizer.pad_token_id = 0
    return base_tokenizer


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default="mistralai-7b-instruct")
    parser.add_argument('--cache_dir', type=str, default="../cache")
    args = parser.parse_args()

    load_tokenizer(args.model_name, args.cache_dir)
    load_model(args.model_name, 'cpu', args.cache_dir)