File size: 1,838 Bytes
dbc69f3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 | from __future__ import annotations
import importlib.util
from typing import Any
import torch
def load_model_and_tokenizer(
    model_name: str,
    trust_remote_code: bool = False,
    cache_dir: str | None = None,
    load_in_4bit: bool = False,
    torch_dtype: str = "bfloat16",
) -> tuple[Any, Any, str]:
    """Load a causal language model and its tokenizer through ``transformers``.

    Args:
        model_name: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model. If it contains ``"llama"``
            (case-insensitive), the Liger kernel patch for Llama is applied
            before the model is constructed.
        trust_remote_code: Forwarded unchanged to both ``from_pretrained``
            calls.
        cache_dir: Forwarded unchanged to both ``from_pretrained`` calls.
        load_in_4bit: Accepted for interface compatibility only. This loader
            never configures quantization; a ``RuntimeWarning`` is emitted
            when it is ``True`` so the request is not dropped silently.
        torch_dtype: Name of a ``torch`` dtype such as ``"bfloat16"``,
            ``"float16"`` or ``"float32"`` (an optional ``"torch."`` prefix
            is tolerated), or a ``torch.dtype`` object. Unrecognised values
            fall back to ``torch.float16`` with a ``RuntimeWarning``.

    Returns:
        A ``(model, tokenizer, "transformers")`` tuple; the last element is
        the literal backend name.

    Raises:
        RuntimeError: If ``model_name`` contains ``"llama"`` and the Liger
            kernel patcher cannot be imported.
    """
    import warnings

    if load_in_4bit:
        warnings.warn(
            "load_in_4bit=True is ignored: this loader does not configure "
            "quantization.",
            RuntimeWarning,
            stacklevel=2,
        )

    # Resolve the requested dtype. Previously every value other than the exact
    # string "bfloat16" became float16, so e.g. "float32" or torch.bfloat16
    # were silently loaded in half precision.
    if isinstance(torch_dtype, torch.dtype):
        dtype = torch_dtype
    else:
        resolved = None
        if isinstance(torch_dtype, str):
            dtype_name = torch_dtype.strip()
            if dtype_name.startswith("torch."):
                dtype_name = dtype_name[len("torch."):]
            resolved = getattr(torch, dtype_name, None)
        if isinstance(resolved, torch.dtype):
            dtype = resolved
        else:
            # Keep the historical float16 fallback, but make it visible.
            warnings.warn(
                f"Unrecognised torch_dtype {torch_dtype!r}; "
                "falling back to float16.",
                RuntimeWarning,
                stacklevel=2,
            )
            dtype = torch.float16

    # Apply Liger kernels before constructing Llama models.
    if "llama" in model_name.lower():
        try:
            from liger_kernel.transformers import apply_liger_kernel_to_llama
        except Exception as exc:
            raise RuntimeError(
                "Failed to import Liger kernel patcher for Llama. "
                "Install liger-kernel in the runtime environment."
            ) from exc
        apply_liger_kernel_to_llama()

    # Default to SDPA. FlashAttention-2 is selected only for half-precision
    # dtypes (the only ones reachable before dtype resolution was fixed) and
    # only when the package actually imports.
    attn_impl = "sdpa"
    half_precision = dtype in (torch.float16, torch.bfloat16)
    if half_precision and importlib.util.find_spec("flash_attn") is not None:
        try:
            __import__("flash_attn")
        except Exception:
            # Best effort: an installed but unimportable flash_attn must not
            # prevent loading the model; stay on SDPA.
            attn_impl = "sdpa"
        else:
            attn_impl = "flash_attention_2"

    # Imported here, after the optional Liger patch has been applied.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=trust_remote_code,
        cache_dir=cache_dir,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=trust_remote_code,
        cache_dir=cache_dir,
        dtype=dtype,
        attn_implementation=attn_impl,
    )
    return model, tokenizer, "transformers"
def generation_kwargs(cfg: Any) -> dict[str, Any]:
    """Collect the generation-related settings of ``cfg`` into a plain dict.

    Args:
        cfg: Any object exposing the attributes ``max_prompt_length``,
            ``max_completion_length``, ``num_generations``, ``temperature``
            and ``top_p``.

    Returns:
        A new dict mapping each of those attribute names to its current
        value on ``cfg``, in the order listed above.

    Raises:
        AttributeError: If ``cfg`` lacks any of the attributes.
    """
    field_names = (
        "max_prompt_length",
        "max_completion_length",
        "num_generations",
        "temperature",
        "top_p",
    )
    return {field: getattr(cfg, field) for field in field_names}
|