# split_turn/make_turn.py
# Header recovered from a Hugging Face file-page capture: uploaded by cccxi
# in verified commit 9b5771b ("Upload LoRA adapter folder"); file size 2.54 kB.
import json
import copy
import argparse
import os
def _split_messages(metadata, messages):
    """Build one sample per assistant message found in *messages*.

    Each sample is a deep copy of *metadata* plus a "messages" key holding
    the conversation prefix up to and including that assistant message.
    If *metadata* has an "id", the copy's id gets a "_turn_N" suffix, where
    N is the 1-based index of the assistant reply, so split samples do not
    collide on id.

    Args:
        metadata: The dialogue's fields other than "messages".
        messages: The dialogue's message list, in conversation order.

    Returns:
        A list of new sample dicts (possibly empty).
    """
    samples = []
    history = []
    assistant_turn_count = 0  # 1-based index of the current assistant reply
    for msg in messages:
        history.append(msg)
        # Non-dict entries stay in the history but can never close a turn;
        # checking the type avoids an AttributeError on msg.get.
        if not (isinstance(msg, dict) and msg.get("role") == "assistant"):
            continue
        assistant_turn_count += 1
        sample = copy.deepcopy(metadata)
        if "id" in sample:
            # e.g. "chat_001" -> "chat_001_turn_1"
            sample["id"] = f"{sample['id']}_turn_{assistant_turn_count}"
        # Deep copy so every sample owns an independent snapshot of the prefix.
        sample["messages"] = copy.deepcopy(history)
        samples.append(sample)
    return samples


def split_dialogue_by_assistant(input_path, output_path):
    """Split each multi-turn dialogue in a JSONL file into one sample per assistant reply.

    Every non-blank line of *input_path* is parsed as a JSON object. Its
    "messages" list is walked in order, and each assistant message produces
    a sample. A sample holds the object's other fields (deep-copied) plus
    the conversation prefix up to and including that reply. An "id" field,
    if present, gets a "_turn_N" suffix (N = 1-based assistant reply index).
    A missing "messages" key is treated as an empty conversation.

    Lines are skipped, and counted in the printed summary instead of aborting
    the run, when they:
    - are not valid JSON,
    - are not JSON objects, or
    - have a "messages" value that is not a list.

    A UTF-8 BOM at the start of the input is tolerated.

    The whole input is read before the output is opened, so *output_path*
    may name the same file as *input_path*. Output is UTF-8 JSONL with
    non-ASCII characters written as-is.

    Args:
        input_path: Path of the JSONL file to read.
        output_path: Path of the JSONL file to write. Missing parent
            directories are created.

    Returns:
        The number of samples written, or None if *input_path* does not
        exist. In that case an error message is printed and nothing is
        written.
    """
    if not os.path.exists(input_path):
        print(f"❌ 錯誤:找不到輸入檔案 '{input_path}'")
        return None

    split_results = []
    original_count = 0
    skipped_count = 0
    print(f"🚀 開始處理:{input_path}")
    # "utf-8-sig" reads plain UTF-8 unchanged but strips a leading BOM. With
    # "utf-8" the BOM survives str.strip() and makes json.loads reject the
    # first record.
    with open(input_path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                skipped_count += 1
                continue
            if not isinstance(entry, dict):
                skipped_count += 1
                continue
            # After the pop, `entry` holds only the per-dialogue metadata.
            messages = entry.pop("messages", [])
            if not isinstance(messages, list):
                skipped_count += 1
                continue
            original_count += 1
            split_results.extend(_split_messages(entry, messages))

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(
            json.dumps(item, ensure_ascii=False) + '\n' for item in split_results
        )

    print("-" * 30)
    print("✅ 處理完成!")
    print(f"📦 原始樣本數:{original_count}")
    print(f"多輪切分後總樣本數:{len(split_results)}")
    if skipped_count:
        print(f"⚠️ 略過格式錯誤的行數:{skipped_count}")
    print("🆔 已自動為 ID 加上 _turn_N 後綴避免重複。")
    return len(split_results)
if __name__ == "__main__":
    # Command-line entry point: split the dialogues in --input and write the
    # per-assistant-turn samples to --output.
    parser = argparse.ArgumentParser(
        description="將多輪對話 JSONL 依每個 assistant 回覆切分為多筆樣本。"
    )
    parser.add_argument("--input", "-i", type=str, required=True,
                        help="輸入 JSONL 檔案路徑")
    parser.add_argument("--output", "-o", type=str, default="output_split.jsonl",
                        help="輸出 JSONL 檔案路徑(預設:%(default)s)")
    args = parser.parse_args()
    # split_dialogue_by_assistant only prints a message on a missing input and
    # the process would still exit 0. parser.error exits non-zero so shell
    # pipelines can detect the failure. isfile also rejects directories,
    # which would otherwise fail later inside open().
    if not os.path.isfile(args.input):
        parser.error(f"找不到輸入檔案 '{args.input}'")
    split_dialogue_by_assistant(args.input, args.output)