ArthurLin committed
Commit 1d1adda · verified · 1 Parent(s): a2cd8f6

Update model.py

Files changed (1)
  1. model.py +23 -1
model.py CHANGED
@@ -5,8 +5,28 @@ import os
  hf_token = os.getenv("LLM_token")
  os.environ["HUGGINGFACE_HUB_TOKEN"] = hf_token

- bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_use_double_quant=True,
+     llm_int8_skip_modules=None
+ )
+
+ def load_model(model_path="meta-llama/Meta-Llama-3-8B-Instruct"):
+     # Don't pass a torch.device to pipeline(); use device_map="auto" instead
+     pipe = pipeline(
+         "text-generation",
+         model=model_path,
+         model_kwargs={
+             "quantization_config": bnb_config,
+             "device_map": "auto",
+             "torch_dtype": torch.float16
+         },
+         token=hf_token
+     )
+     return pipe
+
+ '''
  def load_model(model_path="meta-llama/Meta-Llama-3-8B-Instruct"):
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@@ -19,3 +39,5 @@ def load_model(model_path="meta-llama/Meta-Llama-3-8B-Instruct"):
          token=hf_token
      )
      return pipe
+ '''
+
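
Below is a minimal usage sketch of the updated load_model. The prompt and generation arguments are illustrative assumptions, not part of this commit, and it presumes a CUDA GPU, bitsandbytes installed, and a valid LLM_token in the environment.

# Hypothetical call site, not in the repository.
# load_model() builds a 4-bit NF4 quantized text-generation pipeline.
pipe = load_model()

# A text-generation pipeline accepts a plain string prompt; these
# generation kwargs (max_new_tokens, do_sample) are illustrative.
out = pipe("Explain NF4 4-bit quantization in one sentence.",
           max_new_tokens=64, do_sample=False)
print(out[0]["generated_text"])

With device_map="auto", the accelerate backend places the quantized weights on available devices itself, which is why the commit comments out the manual torch.device selection instead of keeping it.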