neuralese_temp / src /hackable /backends.py
psidharth567's picture
Export neuralese codebase (cache and .env excluded).
dbc69f3
from __future__ import annotations
import importlib.util
from typing import Any
import torch
def _resolve_torch_dtype(torch_dtype: Any) -> torch.dtype:
    """Translate a dtype name (or an actual ``torch.dtype``) into a ``torch.dtype``.

    Names are matched case-insensitively. Unknown names raise instead of being
    silently coerced to half precision.

    Raises:
        ValueError: If ``torch_dtype`` is not one of the supported names.
    """
    if isinstance(torch_dtype, torch.dtype):
        return torch_dtype

    aliases = {
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float16": torch.float16,
        "fp16": torch.float16,
        "half": torch.float16,
        "float32": torch.float32,
        "fp32": torch.float32,
        "float": torch.float32,
    }
    key = str(torch_dtype).strip().lower()
    try:
        return aliases[key]
    except KeyError:
        supported = ", ".join(sorted(aliases))
        raise ValueError(
            f"Unsupported torch_dtype {torch_dtype!r}; expected one of: {supported}"
        ) from None


def _select_attn_implementation() -> str:
    """Return ``"flash_attention_2"`` if ``flash_attn`` imports, else ``"sdpa"``."""
    if importlib.util.find_spec("flash_attn") is None:
        return "sdpa"
    try:
        importlib.import_module("flash_attn")
    except Exception:
        # Best-effort probe: the package can be discoverable yet fail to
        # import, in which case we fall back rather than abort the load.
        return "sdpa"
    return "flash_attention_2"


def load_model_and_tokenizer(
    model_name: str,
    trust_remote_code: bool = False,
    cache_dir: str | None = None,
    load_in_4bit: bool = False,
    torch_dtype: str = "bfloat16",
) -> tuple[Any, Any, str]:
    """Load a causal LM and its tokenizer through Hugging Face Transformers.

    Args:
        model_name: Model identifier or path handed to ``from_pretrained``.
            If it contains "llama" (case-insensitive), the Liger kernel patch
            for Llama is applied before the model is constructed.
        trust_remote_code: Forwarded to both ``from_pretrained`` calls.
        cache_dir: Forwarded to both ``from_pretrained`` calls.
        load_in_4bit: Accepted for interface compatibility but ignored; this
            function does not quantize the model.
        torch_dtype: Name of the dtype to load weights in, e.g. ``"bfloat16"``,
            ``"float16"`` or ``"float32"`` (short aliases such as ``"bf16"``
            are accepted, as is a ``torch.dtype`` instance).

    Returns:
        A ``(model, tokenizer, backend_name)`` tuple, where ``backend_name``
        is always the string ``"transformers"``.

    Raises:
        ValueError: If ``torch_dtype`` is not a recognised dtype name.
        RuntimeError: If a Llama model is requested but the Liger kernel
            patcher cannot be imported.
    """
    del load_in_4bit  # intentionally unused; kept so callers' kwargs still work

    # Validate the cheap input first so a bad dtype fails before heavy imports.
    dtype = _resolve_torch_dtype(torch_dtype)

    # Apply Liger kernels before constructing Llama models.
    if "llama" in model_name.lower():
        try:
            from liger_kernel.transformers import apply_liger_kernel_to_llama
        except Exception as exc:
            raise RuntimeError(
                "Failed to import Liger kernel patcher for Llama. "
                "Install liger-kernel in the runtime environment."
            ) from exc
        apply_liger_kernel_to_llama()

    attn_impl = _select_attn_implementation()

    # Imported lazily, after any Liger patching above.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=trust_remote_code,
        cache_dir=cache_dir,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=trust_remote_code,
        cache_dir=cache_dir,
        dtype=dtype,
        attn_implementation=attn_impl,
    )
    return model, tokenizer, "transformers"
def generation_kwargs(cfg: Any) -> dict[str, Any]:
    """Collect the generation-related attributes of ``cfg`` into a dict.

    Args:
        cfg: Any object exposing ``max_prompt_length``,
            ``max_completion_length``, ``num_generations``, ``temperature``
            and ``top_p`` attributes.

    Returns:
        A dict keyed by those five attribute names, holding the values read
        from ``cfg``.
    """
    field_names = (
        "max_prompt_length",
        "max_completion_length",
        "num_generations",
        "temperature",
        "top_p",
    )
    return {field: getattr(cfg, field) for field in field_names}