"""Generate answers with a finetuned Llama 3 model on truncated evaluation
chats (system prompt + final user turn) and save them for later scoring."""

from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import Dataset
import json
from tqdm import tqdm


def read_raw_data(file_path):
    """Read a JSON Lines file: one JSON object per line."""
    raw_data = []
    with open(file_path, "r") as f:
        for json_str in f:
            raw_data.append(json.loads(json_str))
    return raw_data
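# Illustrative sketch of the expected JSONL schema (inferred from the asserts
# and indexing below, not taken from the dataset itself): each line holds one
# 5-turn chat, e.g.
# {"messages": [{"role": "system", "content": "..."},
#               {"role": "user", "content": "..."},
#               {"role": "assistant", "content": "..."},
#               {"role": "user", "content": "..."},
#               {"role": "assistant", "content": "..."}]}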
def split_one_chat_as_two(raw_dataset):
    """Duplicate every 5-turn chat into two samples: the first
    system/user/assistant exchange, and the full conversation."""
    chat_mul2 = []
    for item in raw_dataset:
        chat_list = item["messages"]
        assert len(chat_list) == 5, "each chat is expected to have exactly 5 messages"
        chat_1 = chat_list[:3]  # first exchange only
        chat_2 = chat_list      # full 5-turn conversation
        chat_mul2.append({"messages": chat_1})
        chat_mul2.append({"messages": chat_2})
    assert len(chat_mul2) == 2 * len(raw_dataset)
    return chat_mul2
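# Illustrative usage (this helper is not called in __main__ below; the path is
# hypothetical):
# raw = read_raw_data("eval.jsonl")
# doubled = split_one_chat_as_two(raw)  # len(doubled) == 2 * len(raw)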
def format_dataset(raw_dataset, fmt_tokenizer):
    """Build an evaluation Dataset with three columns: the full chat minus the
    final assistant reply, a truncated chat (system prompt + final user turn
    only), and the gold final reply as the label."""
    message_list_infer = []
    message_list_label = []
    message_list_trct = []
    for entry in raw_dataset:
        message_list_infer.append(entry["messages"][:-1])
        message_list_label.append(entry["messages"][-1])
        # Drop the middle turns: keep the system prompt (index 0) and the
        # second user message (index 3).
        truncate_chat = [entry["messages"][0], entry["messages"][3]]
        message_list_trct.append(truncate_chat)
    dataset = Dataset.from_dict({
        "complete_chat": message_list_infer,
        "truncate_chat": message_list_trct,
        "label": message_list_label,
    })
    dataset = dataset.map(lambda x: {
        "formatted_chat": fmt_tokenizer.apply_chat_template(
            x["truncate_chat"], tokenize=False, add_generation_prompt=True)
    })
    return dataset
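# Note: Llama 3's chat template renders each turn as
# "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>" and, with
# add_generation_prompt=True, appends an empty assistant header. The split on
# "<|start_header_id|>assistant<|end_header_id|>" in __main__ depends on this.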
if __name__ == "__main__":
    finetuned_path = "checkpoint"
    test_dataset_path = "YOUR_PATH_TO_EMOTIONBENCH"
    llama3_fmt_tokenizer_path = "YOUR_PATH_TO_META_LLAMA-3.1-8B"

    # Load the finetuned model and its tokenizer; a separate tokenizer from the
    # base Meta-Llama-3.1-8B checkpoint supplies the chat template used for
    # formatting.
    llama3_ft_model = AutoModelForCausalLM.from_pretrained(finetuned_path, device_map="auto")
    llama3_tokenizer = AutoTokenizer.from_pretrained(finetuned_path)
    llama3_format_tokenizer = AutoTokenizer.from_pretrained(llama3_fmt_tokenizer_path)

    raw_data = read_raw_data(file_path=test_dataset_path)
    eval_data_formatted = format_dataset(raw_data, llama3_format_tokenizer)
    print(eval_data_formatted)
    ret_list = []
    for sample in tqdm(eval_data_formatted["formatted_chat"], desc="Inferring answers: "):
        inputs = llama3_tokenizer(sample, return_tensors="pt")
        inputs = inputs.to("cuda")

        outputs = llama3_ft_model.generate(**inputs, max_new_tokens=256)[0]
        decoded_text = llama3_tokenizer.decode(outputs)
        # The reply follows the single assistant header appended by
        # add_generation_prompt=True.
        gen_text = decoded_text.split("<|start_header_id|>assistant<|end_header_id|>")[1].strip()
        ret_list.append(gen_text)
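    # A hedged alternative for the extraction step above (assumes a
    # single-sequence batch): slice off the prompt tokens and decode only the
    # continuation, which also drops special tokens such as a trailing
    # <|eot_id|>:
    # new_tokens = outputs[inputs["input_ids"].shape[-1]:]
    # gen_text = llama3_tokenizer.decode(new_tokens, skip_special_tokens=True)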
    # Persist the generated answers and the formatted evaluation data.
    with open("test_answers_with_context_trct.json", "w") as f:
        json.dump(ret_list, f, indent=2)
    eval_data_formatted.to_json("test_data_trct.jsonl")
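    # Illustrative follow-up (hypothetical, not part of the original script):
    # reload the saved answers and pair them with the gold labels for scoring.
    # with open("test_answers_with_context_trct.json") as f:
    #     answers = json.load(f)
    # labels = [m["content"] for m in eval_data_formatted["label"]]
    # assert len(answers) == len(labels)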