| from __future__ import annotations |
|
|
| import importlib.util |
| from typing import Any |
|
|
| import torch |
|
|
|
|
def load_model_and_tokenizer(
    model_name: str,
    trust_remote_code: bool = False,
    cache_dir: str | None = None,
    load_in_4bit: bool = False,
    torch_dtype: str = "bfloat16",
) -> tuple[Any, Any, str]:
    """Load a causal LM and its tokenizer through ``transformers``.

    Args:
        model_name: Identifier forwarded to both ``from_pretrained`` calls.
            When it contains "llama" (case-insensitive) the Liger kernel
            patch for Llama is applied before the model is loaded.
        trust_remote_code: Forwarded to both ``from_pretrained`` calls.
        cache_dir: Forwarded to both ``from_pretrained`` calls.
        load_in_4bit: Accepted for interface compatibility but not acted on;
            a ``RuntimeWarning`` is emitted when it is true so the caller
            knows the model is loaded unquantized.
        torch_dtype: Name of a floating-point torch dtype (for example
            "bfloat16", "float16", "float32"). Unrecognized names fall back
            to ``torch.float16`` with a ``RuntimeWarning``.

    Returns:
        The tuple ``(model, tokenizer, "transformers")``.

    Raises:
        RuntimeError: If the model name contains "llama" and the Liger kernel
            patcher cannot be imported.
    """
    import warnings  # local import keeps this block self-contained

    if load_in_4bit:
        # The flag has no implementation here; say so instead of ignoring it.
        warnings.warn(
            "load_in_4bit=True was requested but is not supported by "
            "load_model_and_tokenizer; loading the model unquantized.",
            RuntimeWarning,
            stacklevel=2,
        )

    dtype = _resolve_torch_dtype(torch_dtype)
    if dtype is None:
        # Preserve the historical fallback, but make it visible.
        warnings.warn(
            f"Unrecognized torch_dtype {torch_dtype!r}; "
            "falling back to torch.float16.",
            RuntimeWarning,
            stacklevel=2,
        )
        dtype = torch.float16

    if "llama" in model_name.lower():
        try:
            from liger_kernel.transformers import apply_liger_kernel_to_llama
        except Exception as exc:
            raise RuntimeError(
                "Failed to import Liger kernel patcher for Llama. "
                "Install liger-kernel in the runtime environment."
            ) from exc
        apply_liger_kernel_to_llama()

    # Prefer FlashAttention 2 only when the package is both discoverable and
    # actually importable; any import failure keeps the SDPA default.
    attn_impl = "sdpa"
    if importlib.util.find_spec("flash_attn") is not None:
        try:
            importlib.import_module("flash_attn")
        except Exception:
            attn_impl = "sdpa"
        else:
            attn_impl = "flash_attention_2"

    # Imported late so the Liger patch above runs before transformers is
    # pulled in by this function.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=trust_remote_code,
        cache_dir=cache_dir,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=trust_remote_code,
        cache_dir=cache_dir,
        dtype=dtype,
        attn_implementation=attn_impl,
    )
    return model, tokenizer, "transformers"


def _resolve_torch_dtype(name: str) -> torch.dtype | None:
    """Return the floating-point ``torch.dtype`` named *name*, else ``None``."""
    candidate = getattr(torch, name, None) if isinstance(name, str) else None
    if isinstance(candidate, torch.dtype) and candidate.is_floating_point:
        return candidate
    return None
|
|
|
|
def generation_kwargs(cfg: Any) -> dict[str, Any]:
    """Collect the generation-related settings of *cfg* into a plain dict.

    Each listed attribute is read from ``cfg`` and stored under a key of the
    same name; a missing attribute raises ``AttributeError``.
    """
    field_names = (
        "max_prompt_length",
        "max_completion_length",
        "num_generations",
        "temperature",
        "top_p",
    )
    return {field: getattr(cfg, field) for field in field_names}
|
|