| import json |
| from copy import deepcopy |
|
|
# Load the raw chat corpus into memory. The name is pre-bound to an empty
# list so it exists before the file is opened.
raw_train_data = []
with open("./QiaoBan/data/child_chat_data.json", "r", encoding="utf-8") as source_file:
    # Read the whole file as text, then parse it as a single JSON document.
    raw_train_data = json.loads(source_file.read())
|
|
def _strip_speaker(utterance):
    """Return *utterance* with its leading ``speaker:`` label removed.

    Both the ASCII colon and the full-width colon (U+FF1A) are accepted as
    the label separator. Only the first separator is treated as the end of
    the label, so colons inside the utterance text itself (e.g. "3:30") are
    preserved. A string without any colon is returned unchanged.
    """
    cut_points = [
        pos
        for pos in (utterance.find(":"), utterance.find("\uff1a"))
        if pos != -1
    ]
    if not cut_points:
        return utterance
    return utterance[min(cut_points) + 1:]


def construct_dialog_sample(dialog):
    """Turn one "</s>"-separated dialog string into per-turn training samples.

    The dialog is split on "</s>"; segments at even positions are treated as
    user utterances and segments at odd positions as assistant replies. Each
    segment may carry a ``speaker:`` label, which is stripped.

    Args:
        dialog: The full dialog as a single string.

    Returns:
        A list with one dict per user/assistant exchange, each holding:
        ``"prompt"`` (user text), ``"response"`` (assistant text) and
        ``"history"`` (a list of ``[user, assistant]`` pairs for all earlier
        exchanges in the same dialog). An empty or blank dialog yields ``[]``.
    """
    turns = dialog.split("</s>")

    # A dialog that ends with "</s>" (or is empty) produces trailing blank
    # segments; drop them so they do not become empty-prompt samples.
    while turns and not turns[-1].strip():
        turns.pop()

    # If the dialog ends on a user turn, pad it with a neutral assistant
    # reply so every prompt has a response.
    if len(turns) % 2 == 1:
        turns.append("智能助手:嗯嗯。")

    chat_data = []
    history = []
    for user_turn, assistant_turn in zip(turns[::2], turns[1::2]):
        user = _strip_speaker(user_turn)
        assistant = _strip_speaker(assistant_turn)
        chat_data.append({
            "prompt": user,
            "response": assistant,
            # Snapshot the history so later appends do not leak into
            # samples that were already emitted.
            "history": deepcopy(history),
        })
        history.append([user, assistant])
    return chat_data
|
|
# Expand every raw dialog into its per-turn samples and write them out,
# one JSON object per line, keeping non-ASCII text unescaped.
with open("chat_train_data.json", "w", encoding="utf-8") as fw:
    for sample in raw_train_data:
        for record in construct_dialog_sample(sample["input"]):
            fw.write(json.dumps(record, ensure_ascii=False) + "\n")