Holi-jhshim commited on
Commit
6b11116
·
verified ·
1 Parent(s): b94bd29

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +14 -3
README.md CHANGED
@@ -106,9 +106,10 @@ Dialogue History:
106
  **Example Code**
107
 
108
  ```python
109
- from transformers import AutoModelForCausalLM, AutoTokenizer
110
  from prompts import overall_evaluation_gen_prompt ## same prompt on "Instruction for TD-Llama-General"
111
- import json,re
 
112
 
113
  def normalize_action(text):
114
  lower_text = text.lower()
@@ -182,8 +183,18 @@ def overall_form(dialogue_idx,full_dialogue,tokenizer_inf,is_val=False):
182
  return label
183
 
184
  device = "cuda:0"
185
- model = AutoModelForCausalLM.from_pretrained("HOLILAB/td-llama-general",device_map={"": device})
 
 
 
 
 
 
 
 
 
186
  tokenizer = AutoTokenizer.from_pretrained("HOLILAB/td-llama-general")
 
187
 
188
  with open("test_dialogue_overall_obs.json",'r') as f: # https://drive.google.com/drive/folders/17IASNvCcJRkHlg2tzMPRlpPbo_MlHs94?hl=ko
189
  dial_list = json.load(f)
 
106
  **Example Code**
107
 
108
  ```python
109
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
110
  from prompts import overall_evaluation_gen_prompt ## same prompt on "Instruction for TD-Llama-General"
111
+ import json,re,torch
112
+ from peft import PeftModel, PeftConfig
113
 
114
  def normalize_action(text):
115
  lower_text = text.lower()
 
183
  return label
184
 
185
  device = "cuda:0"
186
+ base_model = AutoModelForCausalLM.from_pretrained(
187
+ "meta-llama/Llama-3.1-8B-Instruct",
188
+ quantization_config=BitsAndBytesConfig(
189
+ load_in_4bit=True,
190
+ bnb_4bit_quant_type="nf4",
191
+ bnb_4bit_compute_dtype=torch.bfloat16,
192
+ bnb_4bit_use_double_quant=True,
193
+ ),
194
+ device_map={"": device},
195
+ )
196
  tokenizer = AutoTokenizer.from_pretrained("HOLILAB/td-llama-general")
197
+ model = PeftModel.from_pretrained(base_model, "HOLILAB/td-llama-general")
198
 
199
  with open("test_dialogue_overall_obs.json",'r') as f: # https://drive.google.com/drive/folders/17IASNvCcJRkHlg2tzMPRlpPbo_MlHs94?hl=ko
200
  dial_list = json.load(f)