| import json
|
| import copy
|
| import argparse
|
| import os
|
|
|
def split_dialogue_by_assistant(input_path, output_path):
    """Expand multi-turn dialogues in a JSONL file into one sample per assistant turn.

    Each non-blank line of ``input_path`` is parsed as a JSON object. Its
    ``"messages"`` list is walked in order; every time a message whose
    ``"role"`` is ``"assistant"`` is reached, a new sample is emitted that
    holds the conversation up to and including that message. All other
    top-level keys of the record are copied onto every emitted sample, and
    an ``"id"`` key, when present, receives a ``_turn_N`` suffix where N is
    the 1-based index of the assistant turn within that record.

    The whole input is read before the output is opened, then the samples are
    written to ``output_path`` as UTF-8 JSONL and a summary is printed.

    Lines that are not valid JSON, or whose top-level value is not a JSON
    object, are skipped and reported in the summary instead of being dropped
    silently or aborting the run. A ``"messages"`` value that is not a list
    yields no samples, and list items that are not objects are kept in the
    history but never treated as assistant turns.

    Args:
        input_path: Path of the JSONL file to read (UTF-8, with or without BOM).
        output_path: Path of the JSONL file to write; overwritten if present.

    Returns:
        None. If ``input_path`` does not exist, an error message is printed
        and nothing is written.
    """
    if not os.path.exists(input_path):
        print(f"❌ 錯誤:找不到輸入檔案 '{input_path}'")
        return

    split_results = []
    original_count = 0
    skipped_count = 0

    print(f"🚀 開始處理:{input_path}")

    # "utf-8-sig" transparently drops a leading BOM; with plain "utf-8" the
    # BOM stays glued to the first line and json.loads rejects that record.
    with open(input_path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            # Keep the try body minimal so only parse errors are caught here.
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                skipped_count += 1
                continue

            # Valid JSON that is not an object (list, number, ...) has no
            # "messages" key to split on; skip it rather than crash.
            if not isinstance(entry, dict):
                skipped_count += 1
                continue

            original_count += 1

            messages = entry.pop("messages", [])
            if not isinstance(messages, list):
                messages = []
            # Whatever is left after removing "messages" is per-record metadata.
            metadata = entry

            current_history = []
            assistant_turn_count = 0

            for msg in messages:
                current_history.append(msg)

                if not (isinstance(msg, dict) and msg.get("role") == "assistant"):
                    continue

                assistant_turn_count += 1

                # Shallow copies are sufficient: samples are only serialised
                # below and never mutated, so sharing the nested message and
                # metadata objects avoids deep-copying the growing history on
                # every assistant turn.
                new_sample = dict(metadata)
                if "id" in new_sample:
                    new_sample["id"] = f"{new_sample['id']}_turn_{assistant_turn_count}"
                new_sample["messages"] = list(current_history)
                split_results.append(new_sample)

    with open(output_path, 'w', encoding='utf-8') as f:
        for item in split_results:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

    print("-" * 30)
    print("✅ 處理完成!")
    print(f"📦 原始樣本數:{original_count}")
    print(f"多輪切分後總樣本數:{len(split_results)}")
    if skipped_count:
        # Surface dropped lines so data loss is visible to the operator.
        print(f"⚠️ 已略過無法解析或格式不符的行數:{skipped_count}")
    print("🆔 已自動為 ID 加上 _turn_N 後綴避免重複。")
|
|
|
if __name__ == "__main__":
    # Command-line entry point: --input/-i is mandatory, --output/-o falls
    # back to "output_split.jsonl" when omitted.
    cli = argparse.ArgumentParser()
    cli.add_argument("--input", "-i", type=str, required=True)
    cli.add_argument("--output", "-o", type=str, default="output_split.jsonl")
    options = cli.parse_args()

    split_dialogue_by_assistant(
        input_path=options.input,
        output_path=options.output,
    )