File size: 2,421 Bytes
e77cf8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import math
import json
import random
from collections import defaultdict

from regex import T

# Number of available examples per source dataset.
raw_counts = {
    "ToxiCN": 184,
    "COLD": 635,
    "emoji": 49,
    "homo": 39,
    "Cdial-Bias": 262,
    "SWSR": 68,
    "SCCD_single_turn": 285,
    "SCCD_multi_turn": 34,
}

total = sum(raw_counts.values())  # 1556 for the counts above
target_total = 1000

# First pass: give every dataset its rounded proportional share.
proportional_samples = {}
for name, available in raw_counts.items():
    proportional_samples[name] = round(available / total * target_total)

# Rounding can leave the sum slightly off the target, so work on a copy and
# measure how many slots have to be added (diff > 0) or removed (diff < 0).
adjusted_samples = dict(proportional_samples)
current_total = sum(adjusted_samples.values())
diff = target_total - current_total

# Second pass: hand out the correction one slot at a time, cycling through the
# datasets. Surplus slots go to the largest datasets first; slots to remove
# are taken from the smallest datasets first.
if diff:
    step = 1 if diff > 0 else -1
    sorted_keys = sorted(raw_counts, key=raw_counts.get, reverse=diff > 0)
    for offset in range(abs(diff)):
        adjusted_samples[sorted_keys[offset % len(sorted_keys)]] += step

# Report the final per-dataset quotas.
print(f"最终分布(总计 {sum(adjusted_samples.values())}):")
for name, quota in adjusted_samples.items():
    print(f"{name}: {quota}")


###### Begin the train/test split

# JSON file holding every example; each record is indexed by its "dataset"
# field below, so that key must be present on every record.
INPUT_FILE = "/mnt/data/users/liamding/data/sft_zh_tox/data/Style-datasets-idx.json"
# Fixed seed so the exact same split can be regenerated on every run.
SPLIT_SEED = 42

with open(INPUT_FILE, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Group the records by their source dataset.
grouped = defaultdict(list)
for item in data:
    grouped[item["dataset"]].append(item)

# Use a dedicated generator: the split becomes reproducible without
# reseeding (and thereby disturbing) the module-level `random` state.
rng = random.Random(SPLIT_SEED)

# Draw the per-dataset quota computed in `adjusted_samples`.
sampled_data = []
sampled_ids = set()
for dataset, count in adjusted_samples.items():
    pool = grouped.get(dataset, [])
    # Fail with a message that names the dataset instead of the generic
    # "Sample larger than population" error raised by `sample`.
    if count > len(pool):
        raise ValueError(
            f"数据集 {dataset} 只有 {len(pool)} 条,不足以采样 {count} 条"
        )
    samples = rng.sample(pool, count)
    sampled_data.extend(samples)
    # Records are tracked by object identity; this is safe because `data`
    # keeps every record alive for the lifetime of the script.
    sampled_ids.update(id(sample) for sample in samples)

# Everything that was not drawn becomes the test split (original file order).
test_data = [item for item in data if id(item) not in sampled_ids]

# Sanity checks. Raised explicitly (still as AssertionError) so they are not
# stripped when Python runs with -O, and compared against `target_total`
# rather than a hard-coded 1000.
if len(sampled_data) != target_total:
    raise AssertionError(f"采样数不是 {target_total},而是 {len(sampled_data)}")
if len(test_data) + len(sampled_data) != len(data):
    raise AssertionError("总数量不一致")

# Write both splits as indented UTF-8 JSON (non-ASCII text is kept readable
# via ensure_ascii=False). The paths are relative, so the files land in the
# current working directory. The train split is written first.
TRAIN_OUTPUT = "train_1000.json"
TEST_OUTPUT = "test_556.json"
for out_path, records in ((TRAIN_OUTPUT, sampled_data), (TEST_OUTPUT, test_data)):
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(records, f, ensure_ascii=False, indent=4)

# Report where each split was saved.
print(f"训练集(1000 条)已保存至: {TRAIN_OUTPUT}")
print(f"测试集({len(test_data)} 条)已保存至: {TEST_OUTPUT}")