import json
import random
from collections import defaultdict

# Raw per-dataset counts (total = 1556).
raw_counts = {
    'ToxiCN': 184,
    'COLD': 635,
    'emoji': 49,
    'homo': 39,
    'Cdial-Bias': 262,
    'SWSR': 68,
    'SCCD_single_turn': 285,
    'SCCD_multi_turn': 34
}

total = sum(raw_counts.values())  # 1556
target_total = 1000

# First pass: proportional allocation, rounded per dataset.
proportional_samples = {k: round(v / total * target_total) for k, v in raw_counts.items()}

# Adjust the allocation so it sums to exactly 1000 (rounding can leave it off by a few).
adjusted_samples = proportional_samples.copy()
current_total = sum(adjusted_samples.values())
diff = target_total - current_total

# Correction: sort datasets by size, then add/remove one sample at a time.
if diff != 0:
    sorted_keys = sorted(raw_counts, key=lambda k: raw_counts[k], reverse=(diff > 0))
    for i in range(abs(diff)):
        key = sorted_keys[i % len(sorted_keys)]
        adjusted_samples[key] += 1 if diff > 0 else -1

# Report the final distribution.
print(f"Final distribution (total {sum(adjusted_samples.values())}):")
for k, v in adjusted_samples.items():
    print(f"{k}: {v}")

###### Split the data

INPUT_FILE = "/mnt/data/users/liamding/data/sft_zh_tox/data/Style-datasets-idx.json"

# Load the original data.
with open(INPUT_FILE, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Group items by dataset.
grouped = defaultdict(list)
for item in data:
    grouped[item["dataset"]].append(item)

# Sample 1000 items according to the adjusted distribution.
sampled_data = []
sampled_ids = set()

for dataset, count in adjusted_samples.items():
    samples = random.sample(grouped[dataset], count)
    sampled_data.extend(samples)
    # Items in `grouped` are references to the same objects as in `data`,
    # so the in-memory id() works as a unique key for the split below.
    sampled_ids.update(id(sample) for sample in samples)

# Everything not sampled becomes the test set.
test_data = [item for item in data if id(item) not in sampled_ids]

# Sanity checks on the split sizes.
assert len(sampled_data) == 1000, f"Expected 1000 sampled items, got {len(sampled_data)}"
assert len(test_data) + len(sampled_data) == len(data), "Train + test does not equal the original total"

# Save the splits.
TRAIN_OUTPUT = "train_1000.json"
TEST_OUTPUT = "test_556.json"

with open(TRAIN_OUTPUT, 'w', encoding='utf-8') as f:
    json.dump(sampled_data, f, ensure_ascii=False, indent=4)

with open(TEST_OUTPUT, 'w', encoding='utf-8') as f:
    json.dump(test_data, f, ensure_ascii=False, indent=4)

print(f"Train set (1000 items) saved to: {TRAIN_OUTPUT}")
print(f"Test set ({len(test_data)} items) saved to: {TEST_OUTPUT}")
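
# Optional sanity check (a minimal sketch, not part of the original pipeline):
# re-load the saved train split and confirm its per-dataset counts match the
# planned distribution in `adjusted_samples`. Assumes each saved item still
# carries its "dataset" field, as in the input file.
from collections import Counter

with open(TRAIN_OUTPUT, 'r', encoding='utf-8') as f:
    train_check = json.load(f)

train_counts = Counter(item["dataset"] for item in train_check)
for dataset, expected in adjusted_samples.items():
    assert train_counts[dataset] == expected, (
        f"{dataset}: expected {expected}, got {train_counts[dataset]}"
    )
print("Per-dataset train counts match the planned distribution.")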