|
|
import math |
|
|
import json |
|
|
import random |
|
|
from collections import defaultdict |
|
|
|
|
|
from regex import T |
|
|
|
|
|
# Per-dataset record counts used to weight the training sample.
# NOTE(review): presumably these mirror the number of records per "dataset"
# value in INPUT_FILE; they are hard-coded, so re-check them whenever the
# input file changes.
raw_counts = {
    'ToxiCN': 184,
    'COLD': 635,
    'emoji': 49,
    'homo': 39,
    'Cdial-Bias': 262,
    'SWSR': 68,
    'SCCD_single_turn': 285,
    'SCCD_multi_turn': 34,
}

# Number of records to draw for the training split.
target_total = 1000


def allocate_proportionally(counts, target):
    """Split ``target`` draws across the keys of ``counts`` in proportion.

    Uses the largest-remainder (Hamilton) method. Each key first receives the
    floor of its exact quota ``count * target / sum(counts)``. The few leftover
    units then go to the keys with the largest fractional remainders. Ties go
    to the larger raw count first, then to dict order.

    Compared with rounding each quota and patching the difference onto the
    largest or smallest keys, this method:
    - always sums to exactly ``target``;
    - keeps every share within one unit of its exact quota;
    - never produces a negative share.

    Args:
        counts: Mapping of key -> non-negative population size.
        target: Total number of draws; must satisfy
            ``0 <= target <= sum(counts.values())``.

    Returns:
        A new dict with the same keys, in the same order, mapping to ints that
        sum to ``target``. Because target <= population, no share exceeds its
        count.

    Raises:
        ValueError: if any count is negative or ``target`` is out of range.
    """
    if any(c < 0 for c in counts.values()):
        raise ValueError("counts must be non-negative")
    population = sum(counts.values())
    if not 0 <= target <= population:
        raise ValueError(
            f"target must be between 0 and {population}, got {target}"
        )
    if target == 0:
        # Also avoids dividing by a zero population below.
        return dict.fromkeys(counts, 0)

    shares = {}
    remainders = {}
    for key, count in counts.items():
        # Integer arithmetic keeps the remainder comparison exact (no floats).
        shares[key], remainders[key] = divmod(count * target, population)

    leftover = target - sum(shares.values())
    # sorted() stays stable with reverse=True, so full ties keep dict order.
    by_remainder = sorted(
        counts, key=lambda k: (remainders[k], counts[k]), reverse=True
    )
    for key in by_remainder[:leftover]:
        shares[key] += 1
    return shares


adjusted_samples = allocate_proportionally(raw_counts, target_total)

print(f"最终分布(总计 {sum(adjusted_samples.values())}):")
for k, v in adjusted_samples.items():
    print(f"{k}: {v}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
INPUT_FILE = "/mnt/data/users/liamding/data/sft_zh_tox/data/Style-datasets-idx.json"

# Read the whole input file into memory.
# NOTE(review): assumed to be a JSON array of records, each with a "dataset"
# key naming its source corpus (the grouping below indexes that key).
with open(INPUT_FILE, encoding='utf-8') as fh:
    data = json.loads(fh.read())

# Bucket records by source corpus, preserving file order within each bucket.
# A defaultdict is kept on purpose: looking up a corpus with no records
# yields an empty list instead of raising KeyError.
grouped = defaultdict(list)
for record in data:
    source = record["dataset"]
    grouped[source].append(record)
|
|
|
|
|
|
|
|
# Fixed seed so re-running the script reproduces the same train/test split.
# The unseeded global RNG produced a different split on every run, and each
# run overwrote the output files.
SEED = 42
rng = random.Random(SEED)

# Draw the per-dataset quota from each bucket without replacement.
sampled_data = []
sampled_ids = set()
for dataset, count in adjusted_samples.items():
    pool = grouped[dataset]
    if count > len(pool):
        # The quotas come from hard-coded counts; fail with a clear message
        # if the input file holds fewer records than the quota asks for.
        raise ValueError(
            f"数据集 {dataset!r} 只有 {len(pool)} 条,但需要采样 {count} 条"
        )
    samples = rng.sample(pool, count)
    sampled_data.extend(samples)
    # Track chosen records by object identity. Records are presumably dicts
    # (unhashable), and all stay referenced by `data`, so id() is unique here.
    sampled_ids.update(id(sample) for sample in samples)

# Everything not drawn for training, in original file order, is the test set.
test_data = [item for item in data if id(item) not in sampled_ids]
|
|
|
|
|
|
|
|
# Sanity-check the split. Explicit raises (not `assert`) so the checks still
# run under `python -O`; the expected size follows target_total rather than a
# hard-coded 1000.
if len(sampled_data) != target_total:
    raise RuntimeError(f"采样数不是 {target_total},而是 {len(sampled_data)}")
if len(test_data) + len(sampled_data) != len(data):
    raise RuntimeError("总数量不一致")
|
|
|
|
|
|
|
|
# Name each output file after its actual record count so the names stay
# accurate if target_total or the input data change.
TRAIN_OUTPUT = f"train_{len(sampled_data)}.json"
TEST_OUTPUT = f"test_{len(test_data)}.json"

# Write both splits as UTF-8 JSON, keeping Chinese text unescaped.
for out_path, records in ((TRAIN_OUTPUT, sampled_data), (TEST_OUTPUT, test_data)):
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(records, f, ensure_ascii=False, indent=4)

print(f"训练集({len(sampled_data)} 条)已保存至: {TRAIN_OUTPUT}")
print(f"测试集({len(test_data)} 条)已保存至: {TEST_OUTPUT}")