File size: 1,838 Bytes
dbc69f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from __future__ import annotations

import importlib.util
from typing import Any

import torch


def load_model_and_tokenizer(
    model_name: str,
    trust_remote_code: bool = False,
    cache_dir: str | None = None,
    load_in_4bit: bool = False,
    torch_dtype: str = "bfloat16",
) -> tuple[Any, Any, str]:
    """Load a causal language model and its tokenizer through ``transformers``.

    Args:
        model_name: Model identifier or path forwarded to ``from_pretrained``.
            When it contains ``"llama"`` (case-insensitive) the Liger kernel
            patch for Llama is applied before the model is constructed.
        trust_remote_code: Forwarded unchanged to both ``from_pretrained``
            calls.
        cache_dir: Forwarded unchanged to both ``from_pretrained`` calls.
        load_in_4bit: Accepted for interface compatibility only. This loader
            does not act on it; a warning is emitted when it is truthy so the
            request is not dropped silently.
        torch_dtype: ``"bfloat16"`` selects ``torch.bfloat16``. Every other
            value selects ``torch.float16``; a warning is emitted when the
            value is neither ``"bfloat16"`` nor ``"float16"``.

    Returns:
        A ``(model, tokenizer, "transformers")`` tuple.

    Raises:
        RuntimeError: If ``model_name`` refers to a Llama model and the Liger
            kernel patcher cannot be imported.
    """
    # Function-scope import keeps this block self-contained without touching
    # the module-level import section.
    import warnings

    if load_in_4bit:
        warnings.warn(
            "load_in_4bit=True is ignored by load_model_and_tokenizer; "
            "the model is loaded without 4-bit quantization.",
            stacklevel=2,
        )

    if torch_dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        # Only two precisions are supported here. Keep the historical
        # float16 fallback, but tell the caller when the requested value was
        # not actually honoured.
        if torch_dtype != "float16":
            warnings.warn(
                f"Unsupported torch_dtype {torch_dtype!r}; "
                "falling back to 'float16'.",
                stacklevel=2,
            )
        dtype = torch.float16

    # Apply Liger kernels before constructing Llama models.
    if "llama" in model_name.lower():
        try:
            from liger_kernel.transformers import apply_liger_kernel_to_llama
        except Exception as exc:
            raise RuntimeError(
                "Failed to import Liger kernel patcher for Llama. "
                "Install liger-kernel in the runtime environment."
            ) from exc
        apply_liger_kernel_to_llama()

    # Prefer FlashAttention 2 when the package is both discoverable and
    # importable; an installed-but-broken build falls back to SDPA.
    attn_impl = "sdpa"
    if importlib.util.find_spec("flash_attn") is not None:
        try:
            __import__("flash_attn")
            attn_impl = "flash_attention_2"
        except Exception:
            attn_impl = "sdpa"

    # Imported here, after the optional Liger patch above has been applied.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=trust_remote_code,
        cache_dir=cache_dir,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=trust_remote_code,
        cache_dir=cache_dir,
        dtype=dtype,
        attn_implementation=attn_impl,
    )
    return model, tokenizer, "transformers"


def generation_kwargs(cfg: Any) -> dict[str, Any]:
    """Collect the generation-related settings of ``cfg`` into a plain dict.

    Each listed attribute is read from ``cfg`` and stored under a key of the
    same name, in the order shown below. A missing attribute raises
    ``AttributeError``.
    """
    field_names = (
        "max_prompt_length",
        "max_completion_length",
        "num_generations",
        "temperature",
        "top_p",
    )
    return {field: getattr(cfg, field) for field in field_names}