from __future__ import annotations

import importlib.util
from typing import Any

import torch


def load_model_and_tokenizer(
    model_name: str,
    trust_remote_code: bool = False,
    cache_dir: str | None = None,
    load_in_4bit: bool = False,
    torch_dtype: str = "bfloat16",
) -> tuple[Any, Any, str]:
    """Load a causal language model and its tokenizer via Transformers.

    Args:
        model_name: Identifier forwarded to both ``from_pretrained`` calls.
            When it contains ``"llama"`` (case-insensitive), the Liger kernel
            patch for Llama is applied before the model is constructed.
        trust_remote_code: Forwarded to both ``from_pretrained`` calls.
        cache_dir: Forwarded to both ``from_pretrained`` calls.
        load_in_4bit: Accepted for interface compatibility; this loader
            ignores it.
        torch_dtype: ``"bfloat16"`` selects ``torch.bfloat16``; any other
            value selects ``torch.float16``.

    Returns:
        A ``(model, tokenizer, "transformers")`` tuple; the last element is
        a fixed tag string.

    Raises:
        RuntimeError: If the model name contains ``"llama"`` but the Liger
            kernel patcher cannot be imported.
    """
    # The flag is part of the signature but has no effect in this loader.
    del load_in_4bit

    # NOTE: every value other than "bfloat16" falls back to float16.
    dtype = torch.bfloat16 if torch_dtype == "bfloat16" else torch.float16

    # Apply Liger kernels before constructing Llama models.
    if "llama" in model_name.lower():
        try:
            from liger_kernel.transformers import apply_liger_kernel_to_llama
        except Exception as exc:
            raise RuntimeError(
                "Failed to import Liger kernel patcher for Llama. "
                "Install liger-kernel in the runtime environment."
            ) from exc
        apply_liger_kernel_to_llama()

    # Default to SDPA; upgrade to FlashAttention 2 only when the package is
    # both discoverable and importable.
    attn_impl = "sdpa"
    if importlib.util.find_spec("flash_attn") is not None:
        try:
            importlib.import_module("flash_attn")
        except Exception:
            # Best-effort: a present-but-broken flash_attn install must not
            # prevent loading, so keep the SDPA default.
            pass
        else:
            attn_impl = "flash_attention_2"

    # Imported here rather than at module top, so importing this module does
    # not itself import transformers.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=trust_remote_code,
        cache_dir=cache_dir,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=trust_remote_code,
        cache_dir=cache_dir,
        dtype=dtype,
        attn_implementation=attn_impl,
    )
    return model, tokenizer, "transformers"


def generation_kwargs(cfg: Any) -> dict[str, Any]:
    """Collect the generation-related settings of *cfg* into a dict.

    Args:
        cfg: Any object exposing the attributes ``max_prompt_length``,
            ``max_completion_length``, ``num_generations``, ``temperature``
            and ``top_p``.

    Returns:
        A new dict mapping each of those attribute names to its value.

    Raises:
        AttributeError: If *cfg* lacks one of the attributes.
    """
    return {
        "max_prompt_length": cfg.max_prompt_length,
        "max_completion_length": cfg.max_completion_length,
        "num_generations": cfg.num_generations,
        "temperature": cfg.temperature,
        "top_p": cfg.top_p,
    }