lora_ckp / lora_checkpoints /data /QiaoBan /generate_train_data.py
Ray121381's picture
1
e3e3f87
import json
from copy import deepcopy
# Load the raw conversation records. Presumably a JSON list of dicts whose
# "input" field holds one "</s>"-separated dialog string — confirm against the
# data file.
with open("./QiaoBan/data/child_chat_data.json", encoding="utf-8") as source_file:
    raw_train_data = json.load(source_file)
def _strip_speaker(turn):
    """Return the utterance text of *turn* with its leading ``speaker:`` label removed.

    The label is cut at the FIRST ASCII ":" or full-width "：", so colons that
    appear inside the utterance itself (e.g. "10:30") are preserved. A turn that
    contains no separator is returned whole. Surrounding whitespace is stripped.
    """
    positions = [pos for pos in (turn.find(":"), turn.find("：")) if pos != -1]
    if not positions:
        return turn.strip()
    # Both separators are a single character, hence the +1.
    return turn[min(positions) + 1:].strip()


def construct_dialog_sample(dialog):
    """Split one ``</s>``-separated dialog into per-turn chat training samples.

    Segments are treated as alternating user / assistant turns, each of the form
    ``speaker:text``.

    Args:
        dialog: The raw dialog string, with turns separated by ``"</s>"``.

    Returns:
        A list of dicts, one per user/assistant exchange, each with keys
        ``"prompt"`` (user text), ``"response"`` (assistant text) and
        ``"history"`` (list of ``[user, assistant]`` pairs preceding this
        exchange). An empty dialog yields an empty list.
    """
    turns = dialog.split("</s>")
    # A trailing "</s>" (or an empty dialog) leaves empty segments at the end;
    # drop them so they are not mistaken for a real, empty user turn.
    while turns and not turns[-1].strip():
        turns.pop()
    # A dangling user turn gets a generic assistant acknowledgement so every
    # prompt has a response.
    if len(turns) % 2 == 1:
        turns.append("智能助手:嗯嗯。")

    chat_data = []
    history = []
    for user_turn, assistant_turn in zip(turns[0::2], turns[1::2]):
        user = _strip_speaker(user_turn)
        assistant = _strip_speaker(assistant_turn)
        chat_data.append({
            "prompt": user,
            "response": assistant,
            # Snapshot of earlier exchanges; copied so later appends to
            # `history` do not leak into previously emitted samples.
            "history": [list(pair) for pair in history],
        })
        history.append([user, assistant])
    return chat_data
# Expand every raw dialog into per-turn samples and write them as JSON Lines:
# one JSON object per line, with non-ASCII (Chinese) text kept unescaped.
with open("chat_train_data.json", "w", encoding="utf-8") as out_file:
    for record in raw_train_data:
        for turn_sample in construct_dialog_sample(record["input"]):
            out_file.write(json.dumps(turn_sample, ensure_ascii=False) + "\n")