ArthurLin's picture
Update model.py
19a3384 verified
raw
history blame contribute delete
735 Bytes
import torch
from transformers import pipeline, BitsAndBytesConfig
import os
# Hugging Face access token, supplied via the "LLM_token" environment variable.
# os.getenv returns None when the variable is unset; assigning None into
# os.environ raises TypeError, so only export the token when it is present.
hf_token = os.getenv("LLM_token")
if hf_token:
    os.environ["HUGGINGFACE_HUB_TOKEN"] = hf_token
# 4-bit NF4 quantization settings for bitsandbytes.
# NOTE(review): currently unused — the quantization_config argument in
# load_model() below is commented out; kept so it can be re-enabled easily.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                       # quantize weights to 4 bits at load time
    bnb_4bit_quant_type="nf4",               # NormalFloat4 — recommended for LLM weights
    bnb_4bit_use_double_quant=True,          # also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.float16,    # without this, compute defaults to fp32 (slow)
    llm_int8_skip_modules=None,              # no modules excluded from quantization
)
def load_model(model_path="meta-llama/Meta-Llama-3-8B-Instruct"):
    """Build a text-generation pipeline for the given model.

    Parameters
    ----------
    model_path : str
        Hugging Face model id or local path. Defaults to
        Meta-Llama-3-8B-Instruct (gated; requires a valid token).

    Returns
    -------
    transformers.Pipeline
        A "text-generation" pipeline placed on GPU when available,
        otherwise on CPU.
    """
    # Query CUDA availability once and reuse it for both dtype and device,
    # instead of calling torch.cuda.is_available() twice.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    pipe = pipeline(
        "text-generation",
        model=model_path,
        # fp16 halves GPU memory; on CPU keep the model's default dtype.
        model_kwargs={"torch_dtype": torch.float16} if use_cuda else {},
        # To enable 4-bit loading, pass quantization_config=bnb_config here.
        device=device,
        token=hf_token,
    )
    return pipe