Update README.md
Browse files
README.md
CHANGED
|
@@ -106,9 +106,10 @@ Dialogue History:
|
|
| 106 |
**Example Code**
|
| 107 |
|
| 108 |
```python
|
| 109 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 110 |
from prompts import overall_evaluation_gen_prompt ## same prompt on "Instruction for TD-Llama-General"
|
| 111 |
-
import json,re
|
|
|
|
| 112 |
|
| 113 |
def normalize_action(text):
|
| 114 |
lower_text = text.lower()
|
|
@@ -182,8 +183,18 @@ def overall_form(dialogue_idx,full_dialogue,tokenizer_inf,is_val=False):
|
|
| 182 |
return label
|
| 183 |
|
| 184 |
device = "cuda:0"
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
tokenizer = AutoTokenizer.from_pretrained("HOLILAB/td-llama-general")
|
|
|
|
| 187 |
|
| 188 |
with open("test_dialogue_overall_obs.json",'r') as f: # https://drive.google.com/drive/folders/17IASNvCcJRkHlg2tzMPRlpPbo_MlHs94?hl=ko
|
| 189 |
dial_list = json.load(f)
|
|
|
|
| 106 |
**Example Code**
|
| 107 |
|
| 108 |
```python
|
| 109 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 110 |
from prompts import overall_evaluation_gen_prompt ## same prompt on "Instruction for TD-Llama-General"
|
| 111 |
+
import json,re,torch
|
| 112 |
+
from peft import PeftModel, PeftConfig
|
| 113 |
|
| 114 |
def normalize_action(text):
|
| 115 |
lower_text = text.lower()
|
|
|
|
| 183 |
return label
|
| 184 |
|
| 185 |
device = "cuda:0"
|
| 186 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 187 |
+
"meta-llama/Llama-3.1-8B-Instruct",
|
| 188 |
+
quantization_config=BitsAndBytesConfig(
|
| 189 |
+
load_in_4bit=True,
|
| 190 |
+
bnb_4bit_quant_type="nf4",
|
| 191 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 192 |
+
bnb_4bit_use_double_quant=True,
|
| 193 |
+
),
|
| 194 |
+
device_map={"": device},
|
| 195 |
+
)
|
| 196 |
tokenizer = AutoTokenizer.from_pretrained("HOLILAB/td-llama-general")
|
| 197 |
+
model = PeftModel.from_pretrained(base_model, "HOLILAB/td-llama-general")
|
| 198 |
|
| 199 |
with open("test_dialogue_overall_obs.json",'r') as f: # https://drive.google.com/drive/folders/17IASNvCcJRkHlg2tzMPRlpPbo_MlHs94?hl=ko
|
| 200 |
dial_list = json.load(f)
|