import os

import torch
from transformers import pipeline, BitsAndBytesConfig

# Read the Hugging Face access token from the LLM_token environment variable.
hf_token = os.getenv("LLM_token")
if hf_token:
    # HF_TOKEN is the variable huggingface_hub reads for authentication.
    os.environ["HF_TOKEN"] = hf_token

# 4-bit NF4 quantization config (bitsandbytes). Currently unused: the
# quantization_config line in load_model() below is commented out.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    llm_int8_skip_modules=None,
)


def load_model(model_path="meta-llama/Meta-Llama-3-8B-Instruct"):
    """Build a text-generation pipeline for `model_path`, on GPU if available."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pipe = pipeline(
        "text-generation",
        model=model_path,
        # Load fp16 weights on GPU; fall back to the default dtype on CPU.
        model_kwargs={"torch_dtype": torch.float16} if torch.cuda.is_available() else {},
        # To load in 4-bit instead, pass the config through model_kwargs
        # (and drop device=, letting accelerate place the quantized model), e.g.:
        # model_kwargs={"quantization_config": bnb_config},
        device=device,
        token=hf_token,
    )
    return pipe
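

# Example usage: a minimal sketch, not part of the original file. The prompt,
# max_new_tokens, and the __main__ guard are illustrative assumptions; the
# chat-message input format relies on the text-generation pipeline's chat
# template support, which the Llama 3 instruct model provides.
if __name__ == "__main__":
    pipe = load_model()
    messages = [
        {"role": "user", "content": "Explain 4-bit quantization in one sentence."},
    ]
    outputs = pipe(messages, max_new_tokens=64, do_sample=False)
    # With chat input, generated_text holds the full conversation; the last
    # message is the assistant's reply.
    print(outputs[0]["generated_text"][-1]["content"])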