File size: 4,905 Bytes
1c980b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import json
import sys
import random
from collections import defaultdict
def collect_dataset_info(file_path):
    """Scan a JSONL file and group its line numbers by dataset.

    Every line is expected to hold a JSON object with a string
    ``custom_id`` field. The dataset name is the part of ``custom_id``
    before the first ``'-'``. Lines that cannot be used are reported on
    stderr and skipped; they never abort the scan.

    Args:
        file_path: Path of the JSONL file to read.

    Returns:
        A ``(dataset_lines, order)`` tuple. ``dataset_lines`` is a
        ``defaultdict(list)`` mapping each dataset name to the 1-based line
        numbers of its records, in file order. ``order`` lists the dataset
        names in the order they first appear in the file.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    dataset_lines = defaultdict(list)
    order = []
    # Decode as UTF-8 explicitly instead of relying on the locale default.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            try:
                data = json.loads(line.strip())
            except json.JSONDecodeError:
                print(f"Error: Invalid JSON at line {line_num}", file=sys.stderr)
                continue
            # Valid JSON that is not an object (list, number, null, ...) has
            # no 'custom_id' either; report it instead of raising TypeError.
            if not isinstance(data, dict) or 'custom_id' not in data:
                print(f"Error: Missing 'custom_id' at line {line_num}", file=sys.stderr)
                continue
            custom_id = data['custom_id']
            # Only strings can be split into "<dataset>-<rest>".
            if not isinstance(custom_id, str):
                print(f"Error: Invalid custom_id format at line {line_num}", file=sys.stderr)
                continue
            dataset = custom_id.split('-')[0]
            # Membership tests do not create defaultdict entries, so the
            # mapping itself tells us whether this dataset is new.
            if dataset not in dataset_lines:
                order.append(dataset)
            dataset_lines[dataset].append(line_num)
    return dataset_lines, order
def _fail(message):
    """Print ``message`` to stderr and exit the process with status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _allocate_proportionally(capacities, extra):
    """Split ``extra`` units across slots in proportion to ``capacities``.

    Uses the largest-remainder method with exact integer arithmetic, so
    equal remainders are always tie-broken by position (earlier slot first)
    instead of by floating-point rounding noise.

    Args:
        capacities: Non-negative integers, one per slot.
        extra: Number of units to hand out; the caller must ensure
            ``0 <= extra <= sum(capacities)``.

    Returns:
        A list of integers parallel to ``capacities`` that sums to ``extra``
        and never exceeds the matching capacity.
    """
    total_capacity = sum(capacities)
    if total_capacity == 0:
        # Nothing can be handed out (the precondition forces extra == 0).
        return [0] * len(capacities)
    # (whole share, remainder numerator) for each slot.
    shares = [divmod(extra * cap, total_capacity) for cap in capacities]
    counts = [whole for whole, _ in shares]
    leftover = extra - sum(counts)
    # Largest remainder first; ties go to the earlier slot.
    ranked = sorted(range(len(shares)), key=lambda i: (-shares[i][1], i))
    for i in ranked[:leftover]:
        counts[i] += 1
    return counts


def main():
    """Sample N records from a JSONL file, stratified by dataset.

    Command line: ``python sample_datasets.py <input.jsonl> <output.jsonl> <N>``.

    Records are grouped into datasets by ``collect_dataset_info``. Every
    dataset contributes at least 5 records; the remaining ``N - 5 * k``
    (``k`` = number of datasets) are distributed in proportion to how many
    records each dataset has beyond those 5. The chosen records are picked
    at random and written to the output file in their original file order.

    Validation failures are reported on stderr and terminate the process
    with exit status 1.
    """
    import os  # Local import: only needed for the input/output path check.

    min_per_dataset = 5

    if len(sys.argv) != 4:
        _fail("Usage: python sample_datasets.py <input.jsonl> <output.jsonl> <N>")
    input_file, output_file = sys.argv[1], sys.argv[2]
    try:
        N = int(sys.argv[3])
    except ValueError:
        _fail("Error: N must be an integer.")

    # Collect the line numbers of every dataset.
    try:
        dataset_info, dataset_order = collect_dataset_info(input_file)
    except OSError as err:
        _fail(f"Error: Cannot read input file '{input_file}': {err}")

    # Opening the output with 'w' truncates it; if it is the same file as
    # the input, the data would be destroyed before it could be copied.
    if os.path.exists(output_file) and os.path.samefile(input_file, output_file):
        _fail("Error: Output file must be different from the input file.")

    k = len(dataset_info)
    if k == 0:
        _fail("Error: No datasets found in the input file.")

    # Every dataset must be able to supply its guaranteed minimum.
    for dataset, lines in dataset_info.items():
        if len(lines) < min_per_dataset:
            _fail(f"Error: Dataset '{dataset}' has fewer than {min_per_dataset} samples.")

    total_samples = sum(len(lines) for lines in dataset_info.values())
    min_samples = min_per_dataset * k
    if N < min_samples or N > total_samples:
        _fail(f"Error: N must be between {min_samples} and {total_samples}.")

    # Distribute what is left after the guaranteed minimum. N <= total_samples
    # implies the extra never exceeds the combined spare capacity, so the
    # allocation always fits and the per-dataset counts always sum to N.
    available = [len(dataset_info[dataset]) - min_per_dataset for dataset in dataset_order]
    extra_counts = _allocate_proportionally(available, N - min_samples)
    sample_counts = {
        dataset: min_per_dataset + extra
        for dataset, extra in zip(dataset_order, extra_counts)
    }

    # Report the sampling distribution.
    print("\nSampling Distribution:")
    for dataset in dataset_order:
        print(f" - {dataset}: {sample_counts[dataset]} samples")
    print(f"Total samples: {sum(sample_counts.values())} (target: {N})")

    # Randomly pick line numbers, one dataset at a time.
    selected = set()
    for dataset in dataset_order:
        selected.update(random.sample(dataset_info[dataset], sample_counts[dataset]))

    # Copy the chosen lines in file order; stop once all have been written.
    remaining = len(selected)
    with open(input_file, 'r', encoding='utf-8') as infile, \
            open(output_file, 'w', encoding='utf-8') as outfile:
        for line_num, line in enumerate(infile, 1):
            if remaining == 0:
                break
            if line_num in selected:
                outfile.write(line)
                remaining -= 1

    print(f"\nSuccessfully sampled {N} records to {output_file}.")
# Run the sampler only when executed as a script, not when imported.
if __name__ == "__main__":
    main()