ArthurLin committed on
Commit
f5dd377
·
verified ·
1 Parent(s): 8ee592e

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +4 -1
model.py CHANGED
@@ -1,10 +1,12 @@
1
  import torch
2
- from transformers import pipeline
3
  import os
4
 
5
  hf_token = os.getenv("LLM_token")
6
  os.environ["HUGGINGFACE_HUB_TOKEN"] = hf_token
7
 
 
 
8
  def load_model(model_path="meta-llama/Meta-Llama-3-8B-Instruct"):
9
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
 
@@ -12,6 +14,7 @@ def load_model(model_path="meta-llama/Meta-Llama-3-8B-Instruct"):
12
  "text-generation",
13
  model=model_path,
14
  model_kwargs={"torch_dtype": torch.float16} if torch.cuda.is_available() else {},
 
15
  device=device,
16
  token=hf_token
17
  )
 
1
  import torch
2
+ from transformers import pipeline, BitsAndBytesConfig
3
  import os
4
 
5
  hf_token = os.getenv("LLM_token")
6
  os.environ["HUGGINGFACE_HUB_TOKEN"] = hf_token
7
 
8
+ bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
9
+
10
  def load_model(model_path="meta-llama/Meta-Llama-3-8B-Instruct"):
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
 
14
  "text-generation",
15
  model=model_path,
16
  model_kwargs={"torch_dtype": torch.float16} if torch.cuda.is_available() else {},
17
+ quantization_config=bnb_config,
18
  device=device,
19
  token=hf_token
20
  )